diff --git a/rebar.lock b/rebar.lock
new file mode 100644
index 0000000..57afcca
--- /dev/null
+++ b/rebar.lock
@@ -0,0 +1 @@
+[].
diff --git a/src/advanced_ai_strategy.erl b/src/advanced_ai_strategy.erl
index 7c37645..075d08d 100644
--- a/src/advanced_ai_strategy.erl
+++ b/src/advanced_ai_strategy.erl
@@ -10,6 +10,44 @@
     game_history = []    % 游戏历史
 }).
 
+%% Initialization helpers
+init_strategy_model() ->
+    #{
+        parameters => #{
+            risk_factor => 0.5,
+            aggressive_factor => 0.5,
+            defensive_factor => 0.5
+        },
+        history => []
+    }.
+
+init_situation_model() ->
+    #{
+        analysis_weights => #{
+            hand_strength => 0.3,
+            control_level => 0.3,
+            tempo => 0.2,
+            position => 0.2
+        },
+        historical_data => []
+    }.
+
+init_learning_model() ->
+    #{
+        learning_rate => 0.01,
+        discount_factor => 0.9,
+        exploration_rate => 0.1,
+        model_weights => #{},
+        experience_buffer => []
+    }.
+
+init_pattern_database() ->
+    #{
+        basic_patterns => init_basic_patterns(),
+        complex_patterns => init_complex_patterns(),
+        pattern_weights => init_pattern_weights()
+    }.
+
 %% 高级策略初始化
 init_strategy() ->
     #advanced_ai_state{
@@ -104,6 +142,109 @@ analyze_card_patterns(State, GameState) ->
         combo_opportunities => identify_combo_opportunities(CurrentHand, PatternDB)
     }.
 
+%% Initialize basic card patterns
+init_basic_patterns() ->
+    #{
+        singles => [],
+        pairs => [],
+        triples => [],
+        sequences => [],
+        bombs => []
+    }.
+
+%% Initialize complex card patterns
+init_complex_patterns() ->
+    #{
+        airplane => [],
+        four_with_two => [],
+        three_with_one => [],
+        double_sequence => []
+    }.
+
+%% Initialize pattern weights
+init_pattern_weights() ->
+    #{
+        bomb => 1.0,
+        sequence => 0.8,
+        triple => 0.6,
+        pair => 0.4,
+        single => 0.2
+    }.
+
+%% Generate possible actions
+generate_possible_actions(GameState) ->
+    Cards = get_current_hand(GameState),
+    LastPlay = get_last_play(GameState),
+    generate_valid_plays(Cards, LastPlay).
+
+%% Generate valid plays
+generate_valid_plays(Cards, LastPlay) ->
+    case LastPlay of
+        [] -> generate_all_plays(Cards);
+        _ -> generate_greater_plays(Cards, LastPlay)
+    end.
+
+%% Get the current hand
+get_current_hand(GameState) ->
+    maps:get(hand_cards, GameState, []).
+
+%% Get the last play
+get_last_play(GameState) ->
+    maps:get(last_play, GameState, []).
+
+%% Get the opponent list
+get_opponents(GameState) ->
+    maps:get(opponents, GameState, []).
+
+%% Analyze a single opponent
+analyze_single_opponent(Model, Opponent, GameState) ->
+    #{
+        play_style => analyze_play_style(Model, Opponent),
+        remaining_cards => estimate_remaining_cards(Model, GameState),
+        threat_level => calculate_threat_level(Model, GameState)
+    }.
+
+%% Calculate win probability
+calculate_win_probability(State, BaseAnalysis, OpponentAnalysis) ->
+    HandStrength = maps:get(hand_strength, BaseAnalysis),
+    ControlLevel = maps:get(control_level, BaseAnalysis),
+    ThreatLevel = calculate_average_threat(OpponentAnalysis),
+
+    BaseProb = (HandStrength * 0.4) + (ControlLevel * 0.3) + ((1 - ThreatLevel) * 0.3),
+    adjust_probability(BaseProb, State).
+
+%% Calculate the average threat level (0.0 when there is no opponent data)
+calculate_average_threat([]) ->
+    0.0;
+calculate_average_threat(OpponentAnalysis) ->
+    TotalThreat = lists:foldl(
+        fun(Analysis, Acc) ->
+            Acc + maps:get(threat_level, Analysis, 0.0)
+        end,
+        0.0,
+        OpponentAnalysis
+    ),
+    TotalThreat / length(OpponentAnalysis).
+
+%% Adjust probability
+adjust_probability(BaseProb, State) ->
+    StrategyModel = State#advanced_ai_state.strategy_model,
+    Adjustment = calculate_strategy_adjustment(StrategyModel),
+    max(0.0, min(1.0, BaseProb + Adjustment)).
+
+%% Calculate strategy adjustment
+calculate_strategy_adjustment(StrategyModel) ->
+    Parameters = maps:get(parameters, StrategyModel, #{}),
+    RiskFactor = maps:get(risk_factor, Parameters, 0.5),
+    (RiskFactor - 0.5) * 0.2.
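As a sanity check on the weighted blend in calculate_win_probability/3 (illustrative numbers only):

%% With hand_strength = 0.7, control_level = 0.6 and an average threat of 0.4:
%%   BaseProb = 0.7 * 0.4 + 0.6 * 0.3 + (1 - 0.4) * 0.3
%%            = 0.28 + 0.18 + 0.18 = 0.64
%% adjust_probability/2 then shifts this by (risk_factor - 0.5) * 0.2
%% and clamps the result to [0.0, 1.0].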
+
+%% Select the best action
+select_best_action(RefinedActions, _SituationAnalysis) ->
+    case RefinedActions of
+        [] -> pass;
+        Actions ->
+            %% Actions are {Action, Score} pairs: order by score rather than
+            %% by Erlang term order over the whole tuple.
+            [{BestAction, _Score} | _] =
+                lists:sort(fun({_, S1}, {_, S2}) -> S1 >= S2 end, Actions),
+            BestAction
+    end.
+
 %% 蒙特卡洛树搜索
 monte_carlo_tree_search(State, Actions, GameState) ->
     MaxIterations = 1000,
diff --git a/src/advanced_learning.erl b/src/advanced_learning.erl
deleted file mode 100644
index 8716185..0000000
--- a/src/advanced_learning.erl
+++ /dev/null
@@ -1,73 +0,0 @@
--module(advanced_learning).
--export([init/0, train/2, predict/2, update/3]).
-
--record(learning_state, {
-    neural_network,      % 深度神经网络模型
-    experience_buffer,   % 经验回放缓冲
-    model_version,       % 模型版本
-    training_stats       % 训练统计
-}).
-
-%% 神经网络配置
--define(NETWORK_CONFIG, [
-    {input_layer, 512},
-    {hidden_layer_1, 256},
-    {hidden_layer_2, 128},
-    {hidden_layer_3, 64},
-    {output_layer, 32}
-]).
-
-%% 训练参数
--define(LEARNING_RATE, 0.001).
--define(BATCH_SIZE, 64).
--define(EXPERIENCE_BUFFER_SIZE, 10000).
-
-init() ->
-    Network = initialize_neural_network(?NETWORK_CONFIG),
-    #learning_state{
-        neural_network = Network,
-        experience_buffer = queue:new(),
-        model_version = 1,
-        training_stats = #{
-            total_games = 0,
-            win_rate = 0.0,
-            avg_reward = 0.0
-        }
-    }.
-
-train(State, TrainingData) ->
-    % 准备批次数据
-    Batch = prepare_batch(TrainingData, ?BATCH_SIZE),
-
-    % 执行深度学习训练
-    {UpdatedNetwork, Loss} = train_network(State#learning_state.neural_network, Batch),
-
-    % 更新统计信息
-    NewStats = update_training_stats(State#learning_state.training_stats, Loss),
-
-    % 返回更新后的状态
-    State#learning_state{
-        neural_network = UpdatedNetwork,
-        training_stats = NewStats
-    }.
-
-predict(State, Input) ->
-    % 使用神经网络进行预测
-    Features = extract_features(Input),
-    Prediction = neural_network:forward(State#learning_state.neural_network, Features),
-    process_prediction(Prediction).
-
-update(State, Experience, Reward) ->
-    % 更新经验缓冲
-    NewBuffer = update_experience_buffer(State#learning_state.experience_buffer,
-                                         Experience,
-                                         Reward),
-
-    % 判断是否需要进行训练
-    case should_train(NewBuffer) of
-        true ->
-            TrainingData = prepare_training_data(NewBuffer),
-            train(State#learning_state{experience_buffer = NewBuffer}, TrainingData);
-        false ->
-            State#learning_state{experience_buffer = NewBuffer}
-    end.
\ No newline at end of file
diff --git a/src/ai_core.erl b/src/ai_core.erl
index 6ff0b4a..cc606ae 100644
--- a/src/ai_core.erl
+++ b/src/ai_core.erl
@@ -1,176 +1,657 @@
 -module(ai_core).
--export([init_ai/1, make_decision/2, update_strategy/3]).
+
+-export([
+    init/1,
+    make_decision/2,
+    evaluate_hand/1,
+    predict_plays/2,
+    update_knowledge/2,
+    calculate_win_rate/1
+]).
+
+-include("card_types.hrl").
 
 -record(ai_state, {
-    personality,          % aggressive | conservative | balanced
-    strategy_weights,     % 策略权重
-    knowledge_base,       % 知识库
-    game_history = []     % 游戏历史
+    role,                 % dizhu | nongmin
+    hand_cards = [],      % current hand cards
+    known_cards = [],     % cards known to the AI
+    played_cards = [],    % cards already played
+    player_history = [],  % play history per player
+    game_stage,           % early_game | mid_game | end_game
+    strategy_cache = #{}  % strategy cache
 }).
 
-%% AI初始化
-init_ai(Personality) ->
+-record(play_context, {
+    last_play,            % the previous play
+    cards_remaining,      % number of cards left in hand
+    num_greater_cards,    % number of cards that beat the last play
+    control_factor,       % control factor
+    game_stage,           % early_game | mid_game | end_game
+    must_play = false     % whether a play is mandatory
+}).
+
+%% Initialize AI state
+init(Role) ->
     #ai_state{
-        personality = Personality,
-        strategy_weights = init_weights(Personality),
-        knowledge_base = init_knowledge_base()
+        role = Role,
+        game_stage = early_game
     }.
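The hunk above shows only the head of monte_carlo_tree_search/3 (the body lies outside this diff's context). For orientation, a minimal flat Monte-Carlo sketch of the idea; every name below is hypothetical, and the rollout is a random placeholder rather than the module's real simulation:

-module(mcts_sketch).
-export([best_action/3]).

%% Evaluate each candidate action by averaging N random playouts,
%% then pick the action with the highest average reward.
best_action([], _GameState, _Iterations) ->
    pass;
best_action(Actions, GameState, Iterations) ->
    PerAction = max(1, Iterations div length(Actions)),
    Scored = [{Action, average_reward(Action, GameState, PerAction)}
              || Action <- Actions],
    [{Best, _} | _] = lists:sort(fun({_, A}, {_, B}) -> A >= B end, Scored),
    Best.

average_reward(Action, GameState, N) ->
    Total = lists:sum([simulate_playout(Action, GameState)
                       || _ <- lists:seq(1, N)]),
    Total / N.

%% Placeholder rollout: a real one would play the game to the end under a
%% cheap default policy and return 1.0 for a win, 0.0 for a loss.
simulate_playout(_Action, _GameState) ->
    rand:uniform().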
-%% 决策制定
+%% Make a decision
 make_decision(AIState, GameState) ->
-    % 分析当前局势
-    Situation = analyze_situation(GameState),
-
-    % 生成可能的行动
-    PossiblePlays = generate_possible_plays(GameState),
-
-    % 评估每个行动
-    RatedPlays = evaluate_plays(PossiblePlays, AIState, Situation),
+    Context = analyze_context(AIState, GameState),
+    case should_play(Context, AIState) of
+        true ->
+            select_best_play(Context, AIState);
+        false ->
+            {pass, AIState}
+    end.
+
+%% Evaluate a hand
+evaluate_hand(Cards) ->
+    Components = analyze_components(Cards),
+    calculate_hand_value(Components).
+
+%% Predict possible plays
+predict_plays(Cards, LastPlay) ->
+    ValidPlays = generate_valid_plays(Cards, LastPlay),
+    score_potential_plays(ValidPlays, LastPlay).
+
+%% Update the knowledge base
+update_knowledge(AIState, Event) ->
+    update_state_with_event(AIState, Event).
+
+%% Calculate the win rate
+calculate_win_rate(AIState) ->
+    HandStrength = evaluate_hand(AIState#ai_state.hand_cards),
+    PositionValue = evaluate_position(AIState),
+    ControlValue = evaluate_control(AIState),
 
-    % 选择最佳行动
-    select_best_play(RatedPlays, AIState).
+    calculate_probability(HandStrength, PositionValue, ControlValue).
 
-%% 策略更新
-update_strategy(AIState, GameResult, GameHistory) ->
-    NewWeights = adjust_weights(AIState#ai_state.strategy_weights, GameResult),
-    NewKnowledge = update_knowledge(AIState#ai_state.knowledge_base, GameHistory),
-    AIState#ai_state{
-        strategy_weights = NewWeights,
-        knowledge_base = NewKnowledge,
-        game_history = [GameHistory | AIState#ai_state.game_history]
-    }.
+%% Internal functions
 
-%% 内部函数
+%% Analyze the play context
+analyze_context(AIState, GameState) ->
+    LastPlay = get_last_play(GameState),
+    CardsRemaining = length(AIState#ai_state.hand_cards),
+    GreaterCards = count_greater_cards(AIState#ai_state.hand_cards, LastPlay),
+    ControlFactor = calculate_control_factor(AIState, GameState),
+
+    #play_context{
+        last_play = LastPlay,
+        cards_remaining = CardsRemaining,
+        num_greater_cards = GreaterCards,
+        control_factor = ControlFactor,
+        game_stage = AIState#ai_state.game_stage,
+        must_play = must_play(AIState, GameState)
+    }.
+
+%% Decide whether to play
+should_play(Context, AIState) ->
+    case Context#play_context.must_play of
+        true -> true;
+        false ->
+            case Context#play_context.last_play of
+                none -> true;
+                _ ->
+                    should_beat_last_play(Context, AIState)
+            end
+    end.
 
-init_weights(aggressive) ->
-    #{
-        control_weight => 0.8,
-        attack_weight => 0.7,
-        defense_weight => 0.3,
-        risk_weight => 0.6
-    };
-init_weights(conservative) ->
-    #{
-        control_weight => 0.5,
-        attack_weight => 0.4,
-        defense_weight => 0.8,
-        risk_weight => 0.3
-    };
-init_weights(balanced) ->
-    #{
-        control_weight => 0.6,
-        attack_weight => 0.6,
-        defense_weight => 0.6,
-        risk_weight => 0.5
-    }.
+%% Select the best play
+select_best_play(Context, AIState) ->
+    Candidates = generate_candidates(AIState#ai_state.hand_cards, Context),
+    ScoredPlays = [
+        {score_play(Play, Context, AIState), Play}
+        || Play <- Candidates
+    ],
+    select_highest_scored_play(ScoredPlays, AIState).
+
+%% Analyze hand components
+analyze_components(Cards) ->
+    GroupedCards = group_cards_by_value(Cards),
+    #{
+        singles => find_singles(GroupedCards),
+        pairs => find_pairs(GroupedCards),
+        triples => find_triples(GroupedCards),
+        sequences => find_sequences(GroupedCards),
+        bombs => find_bombs(GroupedCards)
+    }.
-analyze_situation(GameState) ->
-    #{
-        hand_strength => evaluate_hand_strength(GameState),
-        control_status => evaluate_control(GameState),
-        opponent_cards => estimate_opponent_cards(GameState),
-        game_stage => determine_game_stage(GameState)
-    }.
+%% Calculate hand value
+calculate_hand_value(Components) ->
+    SinglesValue = calculate_singles_value(maps:get(singles, Components, [])),
+    PairsValue = calculate_pairs_value(maps:get(pairs, Components, [])),
+    TriplesValue = calculate_triples_value(maps:get(triples, Components, [])),
+    SequencesValue = calculate_sequences_value(maps:get(sequences, Components, [])),
+    BombsValue = calculate_bombs_value(maps:get(bombs, Components, [])),
+
+    SinglesValue + PairsValue + TriplesValue + SequencesValue + BombsValue.
+
+%% Generate valid plays
+generate_valid_plays(Cards, LastPlay) ->
+    case LastPlay of
+        none ->
+            generate_leading_plays(Cards);
+        Play ->
+            generate_following_plays(Cards, Play)
+    end.
+
+%% Score potential plays
+score_potential_plays(Plays, LastPlay) ->
+    [{Play, score_play_potential(Play, LastPlay)} || Play <- Plays].
+
+%% Update state from an event
+update_state_with_event(AIState, {play_cards, Player, Cards}) ->
+    NewPlayed = AIState#ai_state.played_cards ++ Cards,
+    NewHistory = [{Player, Cards} | AIState#ai_state.player_history],
+    AIState#ai_state{
+        played_cards = NewPlayed,
+        player_history = NewHistory
+    };
+update_state_with_event(AIState, {game_stage, NewStage}) ->
+    AIState#ai_state{game_stage = NewStage};
+update_state_with_event(AIState, _) ->
+    AIState.
+
+%% Combine factors into a win probability
+calculate_probability(HandStrength, PositionValue, ControlValue) ->
+    BaseProb = HandStrength * 0.5 + PositionValue * 0.3 + ControlValue * 0.2,
+    normalize_probability(BaseProb).
+
+%% Calculate the control factor
+calculate_control_factor(AIState, GameState) ->
+    ControlCards = count_control_cards(AIState#ai_state.hand_cards),
+    RemainingCards = length(AIState#ai_state.hand_cards),
+
+    ControlRatio = ControlCards / max(1, RemainingCards),
+    PositionBonus = calculate_position_bonus(AIState, GameState),
+
+    ControlRatio * PositionBonus.
+
+%% Whether a play is mandatory
+must_play(AIState, GameState) ->
+    is_current_player(AIState, GameState) andalso
+    not has_active_play(GameState).
+
+%% Decide whether to beat the last play
+should_beat_last_play(Context, AIState) ->
+    case Context#play_context.last_play of
+        none -> true;
+        LastPlay ->
+            HandStrength = evaluate_hand(AIState#ai_state.hand_cards),
+            ControlLevel = Context#play_context.control_factor,
+            CardsLeft = Context#play_context.cards_remaining,
+
+            should_beat(HandStrength, ControlLevel, CardsLeft, LastPlay)
+    end.
+
+%% Generate candidate plays
+generate_candidates(Cards, Context) ->
+    BasePlays = case Context#play_context.last_play of
+        none -> generate_leading_plays(Cards);
+        LastPlay -> generate_following_plays(Cards, LastPlay)
+    end,
+
+    filter_candidates(BasePlays, Context).
+
+%% Score a play
+score_play(Play, Context, AIState) ->
+    BaseScore = calculate_base_score(Play),
+    TempoScore = calculate_tempo_score(Play, Context),
+    ControlScore = calculate_control_score(Play, Context, AIState),
+    EfficiencyScore = calculate_efficiency_score(Play, Context),
+
+    FinalScore = BaseScore * 0.4 +
+                 TempoScore * 0.2 +
+                 ControlScore * 0.3 +
+                 EfficiencyScore * 0.1,
+
+    adjust_score_for_context(FinalScore, Play, Context, AIState).
+
+%% Pick the highest-scoring play
+select_highest_scored_play(ScoredPlays, AIState) ->
+    case lists:sort(fun({Score1, _}, {Score2, _}) ->
+             Score1 >= Score2
+         end, ScoredPlays) of
+        [{Score, Play}|_] when Score > 0 ->
+            {play, Play, update_after_play(Play, AIState)};
+        _ ->
+            {pass, AIState}
+    end.
+
+%% Value of singles
+calculate_singles_value(Singles) ->
+    lists:sum([
+        case Value of
+            V when V >= ?CARD_2 -> Value * 1.5;
+            _ -> Value
+        end || {Value, _} <- Singles
+    ]).
+
+%% Value of pairs
+calculate_pairs_value(Pairs) ->
+    lists:sum([Value * 2.2 || {Value, _} <- Pairs]).
+
+%% Value of triples
+calculate_triples_value(Triples) ->
+    lists:sum([Value * 3.5 || {Value, _} <- Triples]).
+
+%% Value of sequences
+calculate_sequences_value(Sequences) ->
+    lists:sum([
+        Value * length(Cards) * 1.8
+        || {Value, Cards} <- Sequences
+    ]).
+
+%% Value of bombs
+calculate_bombs_value(Bombs) ->
+    lists:sum([Value * 10.0 || {Value, _} <- Bombs]).
+
+%% Generate leading plays
+generate_leading_plays(Cards) ->
+    Components = analyze_components(Cards),
+    Singles = generate_single_plays(Components),
+    Pairs = generate_pair_plays(Components),
+    Triples = generate_triple_plays(Components),
+    Sequences = generate_sequence_plays(Components),
+    Bombs = generate_bomb_plays(Components),
+
+    Singles ++ Pairs ++ Triples ++ Sequences ++ Bombs.
+
+%% Generate following plays
+generate_following_plays(Cards, {Type, Value, _}) ->
+    ValidPlays = find_greater_plays(Cards, Type, Value),
+    BombPlays = find_bomb_plays(Cards),
+    RocketPlay = find_rocket_play(Cards),
+
+    ValidPlays ++ BombPlays ++ RocketPlay.
+
+%% Base score of a play
+calculate_base_score({Type, Value, Cards}) ->
+    BaseValue = Value * length(Cards),
+    TypeMultiplier = case Type of
+        ?CARD_TYPE_ROCKET -> 100.0;
+        ?CARD_TYPE_BOMB -> 80.0;
+        ?CARD_TYPE_STRAIGHT -> 40.0;
+        ?CARD_TYPE_STRAIGHT_PAIR -> 35.0;
+        ?CARD_TYPE_PLANE -> 30.0;
+        ?CARD_TYPE_THREE_TWO -> 25.0;
+        ?CARD_TYPE_THREE_ONE -> 20.0;
+        ?CARD_TYPE_THREE -> 15.0;
+        ?CARD_TYPE_PAIR -> 10.0;
+        ?CARD_TYPE_SINGLE -> 5.0
+    end,
+    BaseValue * TypeMultiplier / 100.0.
+
+%% Tempo score
+calculate_tempo_score(Play, Context) ->
+    case Context#play_context.game_stage of
+        early_game -> calculate_early_tempo(Play, Context);
+        mid_game -> calculate_mid_tempo(Play, Context);
+        end_game -> calculate_end_tempo(Play, Context)
+    end.
+
+%% Control score
+calculate_control_score(Play, Context, _AIState) ->
+    {Type, Value, _} = Play,
+    BaseControl = case Type of
+        ?CARD_TYPE_BOMB -> 1.0;
+        ?CARD_TYPE_ROCKET -> 1.0;
+        _ when Value >= ?CARD_2 -> 0.8;
+        _ -> 0.5
+    end,
+    BaseControl * Context#play_context.control_factor.
+
+%% Efficiency score
+calculate_efficiency_score(Play, Context) ->
+    {_, _, Cards} = Play,
+    CardsUsed = length(Cards),
+    RemainingCards = Context#play_context.cards_remaining - CardsUsed,
+    Efficiency = CardsUsed / max(1, Context#play_context.cards_remaining),
+    Efficiency * (1 + (20 - RemainingCards) / 20).
+
+%% Adjust the score for context
+adjust_score_for_context(Score, _Play, Context, AIState) ->
+    RoleMultiplier = case AIState#ai_state.role of
+        dizhu -> 1.2;
+        nongmin -> 1.0
+    end,
+
+    StageMultiplier = case Context#play_context.game_stage of
+        early_game -> 0.9;
+        mid_game -> 1.0;
+        end_game -> 1.1
+    end,
+
+    Score * RoleMultiplier * StageMultiplier.
+
+%% Update state after playing
+update_after_play(Play, AIState) ->
+    {_, _, Cards} = Play,
+    NewHand = AIState#ai_state.hand_cards -- Cards,
+    NewPlayed = AIState#ai_state.played_cards ++ Cards,
+
+    AIState#ai_state{
+        hand_cards = NewHand,
+        played_cards = NewPlayed
+    }.
-generate_possible_plays(GameState) ->
-    MyCards = get_my_cards(GameState),
-    LastPlay = get_last_play(GameState),
-    generate_valid_plays(MyCards, LastPlay).
-
-evaluate_plays(Plays, AIState, Situation) ->
-    lists:map(
-        fun(Play) ->
-            Score = calculate_play_score(Play, AIState, Situation),
-            {Play, Score}
-        end,
-        Plays
-    ).
-
-calculate_play_score(Play, AIState, Situation) ->
-    Weights = AIState#ai_state.strategy_weights,
-
-    ControlScore = evaluate_control_value(Play, Situation) *
-                   maps:get(control_weight, Weights),
-    AttackScore = evaluate_attack_value(Play, Situation) *
-                  maps:get(attack_weight, Weights),
-    DefenseScore = evaluate_defense_value(Play, Situation) *
-                   maps:get(defense_weight, Weights),
-    RiskScore = evaluate_risk_value(Play, Situation) *
-                maps:get(risk_weight, Weights),
-
-    ControlScore + AttackScore + DefenseScore + RiskScore.
-
-select_best_play(RatedPlays, AIState) ->
-    case AIState#ai_state.personality of
-        aggressive ->
-            select_aggressive(RatedPlays);
-        conservative ->
-            select_conservative(RatedPlays);
-        balanced ->
-            select_balanced(RatedPlays)
-    end.
-
-%% 策略选择函数
-
-select_aggressive(RatedPlays) ->
-    % 倾向于选择得分最高的行动
-    {Play, _Score} = lists:max(RatedPlays),
-    Play.
-
-select_conservative(RatedPlays) ->
-    % 倾向于选择风险较低的行动
-    SafePlays = filter_safe_plays(RatedPlays),
-    case SafePlays of
-        [] -> select_balanced(RatedPlays);
-        _ -> select_from_safe_plays(SafePlays)
-    end.
-
-select_balanced(RatedPlays) ->
-    % 在得分和风险之间寻找平衡
-    {Play, _Score} = select_balanced_play(RatedPlays),
-    Play.
-
-%% 评估函数
-
-evaluate_hand_strength(GameState) ->
-    Cards = get_my_cards(GameState),
-    calculate_hand_value(Cards).
-
-evaluate_control(GameState) ->
-    % 评估是否控制局势
-    LastPlay = get_last_play(GameState),
-    MyCards = get_my_cards(GameState),
-    can_control_game(MyCards, LastPlay).
-
-estimate_opponent_cards(GameState) ->
-    % 基于已出牌情况估计对手手牌
-    PlayedCards = get_played_cards(GameState),
-    MyCards = get_my_cards(GameState),
-    estimate_remaining_cards(PlayedCards, MyCards).
-
-%% 知识库更新
-
-update_knowledge(KnowledgeBase, GameHistory) ->
-    % 更新AI的知识库
-    NewPatterns = extract_patterns(GameHistory),
-    merge_knowledge(KnowledgeBase, NewPatterns).
-
-extract_patterns(GameHistory) ->
-    % 从游戏历史中提取出牌模式
-    lists:foldl(
-        fun(Play, Patterns) ->
-            Pattern = analyze_play_pattern(Play),
-            update_pattern_stats(Pattern, Patterns)
-        end,
-        #{},
-        GameHistory
-    ).
-
-merge_knowledge(Old, New) ->
-    maps:merge_with(
-        fun(_Key, OldValue, NewValue) ->
-            update_knowledge_value(OldValue, NewValue)
-        end,
-        Old,
-        New
-    ).
\ No newline at end of file
+%% Helper functions
+
+normalize_probability(P) ->
+    max(0.0, min(1.0, P)).
+
+count_control_cards(Cards) ->
+    length([C || {V, _} = C <- Cards, V >= ?CARD_2]).
+
+calculate_position_bonus(AIState, GameState) ->
+    case get_position(AIState, GameState) of
+        first -> 1.2;
+        middle -> 1.0;
+        last -> 0.8
+    end.
+
+get_position(_AIState, _GameState) ->
+    % Simplified placeholder; a real implementation would derive the
+    % seat position from the game state.
+    first.
+
+is_current_player(_AIState, _GameState) ->
+    % Simplified placeholder; a real implementation would check whose
+    % turn it is in the game state.
+    true.
+
+has_active_play(_GameState) ->
+    % Simplified placeholder; a real implementation would check whether
+    % an unbeaten play is on the table.
+    false.
+
+should_beat(HandStrength, ControlLevel, CardsLeft, LastPlay) ->
+    BaseThreshold = 0.6,
+    StrengthFactor = HandStrength / 100,
+    ControlFactor = ControlLevel / 100,
+    CardsFactor = (20 - CardsLeft) / 20,
+
+    PlayThreshold = BaseThreshold * (StrengthFactor + ControlFactor + CardsFactor) / 3,
+    evaluate_play_value(LastPlay) < PlayThreshold.
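Plugging representative numbers into should_beat/4 (illustrative only):

%% HandStrength = 60, ControlLevel = 50, CardsLeft = 10:
%%   PlayThreshold = 0.6 * (60/100 + 50/100 + (20-10)/20) / 3
%%                 = 0.6 * (0.6 + 0.5 + 0.5) / 3 = 0.32
%% The AI only tries to beat the last play while
%% evaluate_play_value(LastPlay) < 0.32, i.e. it lets strong plays
%% (bombs, planes, long straights) pass.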
+
+evaluate_play_value({Type, Value, Cards}) ->
+    BaseValue = case Type of
+        ?CARD_TYPE_ROCKET -> 1.0;
+        ?CARD_TYPE_BOMB -> 0.9;
+        ?CARD_TYPE_PLANE -> 0.7;
+        ?CARD_TYPE_STRAIGHT -> 0.6;
+        ?CARD_TYPE_STRAIGHT_PAIR -> 0.5;
+        ?CARD_TYPE_THREE_TWO -> 0.4;
+        ?CARD_TYPE_THREE_ONE -> 0.3;
+        ?CARD_TYPE_THREE -> 0.25;
+        ?CARD_TYPE_PAIR -> 0.2;
+        ?CARD_TYPE_SINGLE -> 0.1
+    end,
+
+    ValueBonus = Value / ?CARD_JOKER_BIG,
+    CardCountFactor = length(Cards) / 10,
+
+    BaseValue * (1 + ValueBonus) * (1 + CardCountFactor).
+
+%% Play generation
+generate_single_plays(Components) ->
+    [{?CARD_TYPE_SINGLE, Value, [Card]} ||
+        {Value, Card} <- maps:get(singles, Components, [])].
+
+generate_pair_plays(Components) ->
+    [{?CARD_TYPE_PAIR, Value, Cards} ||
+        {Value, Cards} <- maps:get(pairs, Components, [])].
+
+generate_triple_plays(Components) ->
+    Triples = maps:get(triples, Components, []),
+    BasicTriples = [{?CARD_TYPE_THREE, Value, Cards} ||
+        {Value, Cards} <- Triples],
+    % plain triples plus three-with-one and three-with-two combinations
+    BasicTriples ++ generate_triple_combinations(Triples, Components).
+
+generate_sequence_plays(Components) ->
+    Sequences = maps:get(sequences, Components, []),
+    [{?CARD_TYPE_STRAIGHT, Value, Cards} ||
+        {Value, Cards} <- Sequences].
+
+generate_bomb_plays(Components) ->
+    [{?CARD_TYPE_BOMB, Value, Cards} ||
+        {Value, Cards} <- maps:get(bombs, Components, [])].
+
+generate_triple_combinations(Triples, Components) ->
+    Singles = maps:get(singles, Components, []),
+    Pairs = maps:get(pairs, Components, []),
+
+    ThreeOne = generate_three_one(Triples, Singles),
+    ThreeTwo = generate_three_two(Triples, Pairs),
+
+    ThreeOne ++ ThreeTwo.
+
+generate_three_one(Triples, Singles) ->
+    [{?CARD_TYPE_THREE_ONE, TripleValue, TripleCards ++ [SingleCard]} ||
+        {TripleValue, TripleCards} <- Triples,
+        {SingleValue, SingleCard} <- Singles,
+        SingleValue =/= TripleValue].
+
+generate_three_two(Triples, Pairs) ->
+    [{?CARD_TYPE_THREE_TWO, TripleValue, TripleCards ++ PairCards} ||
+        {TripleValue, TripleCards} <- Triples,
+        {PairValue, PairCards} <- Pairs,
+        PairValue =/= TripleValue].
+
+%% Find plays of a specific type
+find_greater_plays(Cards, Type, MinValue) ->
+    Components = analyze_components(Cards),
+    case Type of
+        ?CARD_TYPE_SINGLE ->
+            find_greater_singles(Components, MinValue);
+        ?CARD_TYPE_PAIR ->
+            find_greater_pairs(Components, MinValue);
+        ?CARD_TYPE_THREE ->
+            find_greater_triples(Components, MinValue);
+        ?CARD_TYPE_THREE_ONE ->
+            find_greater_three_one(Components, MinValue);
+        ?CARD_TYPE_THREE_TWO ->
+            find_greater_three_two(Components, MinValue);
+        ?CARD_TYPE_STRAIGHT ->
+            find_greater_straight(Components, MinValue);
+        ?CARD_TYPE_STRAIGHT_PAIR ->
+            find_greater_straight_pair(Components, MinValue);
+        ?CARD_TYPE_PLANE ->
+            find_greater_plane(Components, MinValue);
+        ?CARD_TYPE_BOMB ->
+            find_greater_bomb(Components, MinValue);
+        _ -> []
+    end.
+
+find_greater_singles(Components, MinValue) ->
+    [{?CARD_TYPE_SINGLE, Value, [Card]} ||
+        {Value, Card} <- maps:get(singles, Components, []),
+        Value > MinValue].
+
+find_greater_pairs(Components, MinValue) ->
+    [{?CARD_TYPE_PAIR, Value, Cards} ||
+        {Value, Cards} <- maps:get(pairs, Components, []),
+        Value > MinValue].
+
+find_greater_triples(Components, MinValue) ->
+    [{?CARD_TYPE_THREE, Value, Cards} ||
+        {Value, Cards} <- maps:get(triples, Components, []),
+        Value > MinValue].
+
+find_greater_three_one(Components, MinValue) ->
+    Triples = [{V, C} || {V, C} <- maps:get(triples, Components, []),
+               V > MinValue],
+    Singles = maps:get(singles, Components, []),
+    generate_three_one(Triples, Singles).
+
+find_greater_three_two(Components, MinValue) ->
+    Triples = [{V, C} || {V, C} <- maps:get(triples, Components, []),
+               V > MinValue],
+    Pairs = maps:get(pairs, Components, []),
+    generate_three_two(Triples, Pairs).
+
+find_greater_straight(Components, MinValue) ->
+    Sequences = maps:get(sequences, Components, []),
+    [{?CARD_TYPE_STRAIGHT, Value, Cards} ||
+        {Value, Cards} <- Sequences,
+        Value > MinValue].
+
+find_greater_straight_pair(Components, MinValue) ->
+    Sequences = find_pair_sequences(Components),
+    [{?CARD_TYPE_STRAIGHT_PAIR, Value, Cards} ||
+        {Value, Cards} <- Sequences,
+        Value > MinValue].
+
+find_greater_plane(Components, MinValue) ->
+    Planes = find_planes(Components),
+    [{?CARD_TYPE_PLANE, Value, Cards} ||
+        {Value, Cards} <- Planes,
+        Value > MinValue].
+
+find_greater_bomb(Components, MinValue) ->
+    [{?CARD_TYPE_BOMB, Value, Cards} ||
+        {Value, Cards} <- maps:get(bombs, Components, []),
+        Value > MinValue].
+
+find_bomb_plays(Cards) ->
+    Components = analyze_components(Cards),
+    [{?CARD_TYPE_BOMB, Value, BombCards} ||
+        {Value, BombCards} <- maps:get(bombs, Components, [])].
+
+find_rocket_play(Cards) ->
+    Components = analyze_components(Cards),
+    case find_rocket(Components) of
+        {ok, Rocket} -> [Rocket];
+        _ -> []
+    end.
+
+find_rocket(Components) ->
+    case {find_card(?CARD_JOKER_SMALL, Components),
+          find_card(?CARD_JOKER_BIG, Components)} of
+        {{ok, Small}, {ok, Big}} ->
+            {ok, {?CARD_TYPE_ROCKET, ?CARD_JOKER_BIG, [Small, Big]}};
+        _ ->
+            false
+    end.
+
+find_card(Value, Components) ->
+    Singles = maps:get(singles, Components, []),
+    case lists:keyfind(Value, 1, Singles) of
+        {Value, Card} -> {ok, Card};
+        _ -> false
+    end.
+
+find_pair_sequences(Components) ->
+    Pairs = maps:get(pairs, Components, []),
+    find_consecutive_pairs(lists:sort(Pairs), []).
+
+find_planes(Components) ->
+    Triples = maps:get(triples, Components, []),
+    find_consecutive_triples(lists:sort(Triples), []).
+
+find_consecutive_pairs([], Acc) -> lists:reverse(Acc);
+find_consecutive_pairs([{V1, Cards1} | Rest], Acc) ->
+    case find_consecutive_pair_sequence(V1, Cards1, Rest) of
+        %% a straight pair needs at least three consecutive pairs (six cards)
+        {Sequence, NewRest} when length(Sequence) >= 6 ->
+            find_consecutive_pairs(NewRest, [{V1, Sequence} | Acc]);
+        _ ->
+            find_consecutive_pairs(Rest, Acc)
+    end.
+
+find_consecutive_pair_sequence(Value, Cards, Rest) ->
+    find_consecutive_pair_sequence(Value, Cards, Rest, [Cards]).
+
+find_consecutive_pair_sequence(_Value, _, [], Acc) ->
+    {lists:flatten(lists:reverse(Acc)), []};
+find_consecutive_pair_sequence(Value, _, [{NextValue, NextCards} | Rest], Acc)
+  when NextValue =:= Value + 1 ->
+    find_consecutive_pair_sequence(NextValue, NextCards, Rest, [NextCards | Acc]);
+find_consecutive_pair_sequence(_, _, Rest, Acc) ->
+    {lists:flatten(lists:reverse(Acc)), Rest}.
+
+find_consecutive_triples([], Acc) -> lists:reverse(Acc);
+find_consecutive_triples([{V1, Cards1} | Rest], Acc) ->
+    case find_consecutive_triple_sequence(V1, Cards1, Rest) of
+        {Sequence, NewRest} when length(Sequence) >= 6 ->
+            find_consecutive_triples(NewRest, [{V1, Sequence} | Acc]);
+        _ ->
+            find_consecutive_triples(Rest, Acc)
+    end.
+
+find_consecutive_triple_sequence(Value, Cards, Rest) ->
+    find_consecutive_triple_sequence(Value, Cards, Rest, [Cards]).
+
+find_consecutive_triple_sequence(_Value, _, [], Acc) ->
+    {lists:flatten(lists:reverse(Acc)), []};
+find_consecutive_triple_sequence(Value, _, [{NextValue, NextCards} | Rest], Acc)
+  when NextValue =:= Value + 1 ->
+    find_consecutive_triple_sequence(NextValue, NextCards, Rest, [NextCards | Acc]);
+find_consecutive_triple_sequence(_, _, Rest, Acc) ->
+    {lists:flatten(lists:reverse(Acc)), Rest}.
+
+%% Tempo scores for the early, mid and end game
+calculate_early_tempo({Type, Value, _}, _Context) ->
+    case Type of
+        ?CARD_TYPE_SINGLE when Value < ?CARD_2 -> 0.8;
+        ?CARD_TYPE_PAIR when Value < ?CARD_2 -> 0.7;
+        ?CARD_TYPE_STRAIGHT -> 0.9;
+        ?CARD_TYPE_STRAIGHT_PAIR -> 0.85;
+        _ -> 0.5
+    end.
+
+calculate_mid_tempo({Type, _Value, _}, _Context) ->
+    case Type of
+        ?CARD_TYPE_THREE_ONE -> 0.8;
+        ?CARD_TYPE_THREE_TWO -> 0.85;
+        ?CARD_TYPE_PLANE -> 0.9;
+        ?CARD_TYPE_BOMB -> 0.7;
+        _ -> 0.6
+    end.
+
+calculate_end_tempo({Type, _Value, _}, Context) ->
+    CardsLeft = Context#play_context.cards_remaining,
+    case Type of
+        ?CARD_TYPE_BOMB -> 0.9;
+        ?CARD_TYPE_ROCKET -> 1.0;
+        _ when CardsLeft =< 4 -> 0.95;
+        _ -> 0.7
+    end.
+
+%% Filter candidate plays
+filter_candidates(Plays, Context) ->
+    case Context#play_context.game_stage of
+        early_game ->
+            filter_early_game_plays(Plays, Context);
+        mid_game ->
+            filter_mid_game_plays(Plays, Context);
+        end_game ->
+            filter_end_game_plays(Plays, Context)
+    end.
+
+filter_early_game_plays(Plays, _Context) ->
+    % Early game: prefer small plays and hold back low-value bombs.
+    [Play || {Type, Value, _} = Play <- Plays,
+             Type =/= ?CARD_TYPE_BOMB orelse Value >= ?CARD_2].
+
+filter_mid_game_plays(Plays, Context) ->
+    % Mid game: only spend bombs while control is weak.
+    case Context#play_context.control_factor < 0.5 of
+        true -> Plays;
+        false ->
+            [Play || {Type, _, _} = Play <- Plays,
+                     Type =/= ?CARD_TYPE_BOMB]
+    end.
+
+filter_end_game_plays(Plays, _Context) ->
+    % End game: any play type is allowed.
+    Plays.
+
+%% Group hand cards by value
+group_cards_by_value(Cards) ->
+    lists:foldl(fun(Card, Acc) ->
+        {Value, _} = Card,
+        maps:update_with(Value,
+                         fun(List) -> [Card|List] end,
+                         [Card],
+                         Acc)
+    end, maps:new(), Cards).
\ No newline at end of file
diff --git a/src/ai_optimizer.erl b/src/ai_optimizer.erl
index 465d18c..7186aca 100644
--- a/src/ai_optimizer.erl
+++ b/src/ai_optimizer.erl
@@ -1,6 +1,42 @@
 -module(ai_optimizer).
 -export([optimize_ai_system/2, tune_parameters/2]).
 
+-include("../include/game_records.hrl").
+
+%% Analyze performance metrics
+analyze_performance_metrics(Metrics) ->
+    #{
+        win_rate => maps:get(win_rate, Metrics, 0.0),
+        avg_decision_time => maps:get(avg_decision_time, Metrics, 0.0),
+        strategy_effectiveness => maps:get(strategy_effectiveness, Metrics, 0.0),
+        learning_rate => maps:get(learning_rate, Metrics, 0.01)
+    }.
+
+%% Optimize the decision system
+optimize_decision_system(AIState, PerformanceAnalysis) ->
+    CurrentSystem = AIState#ai_state.strategy_model,
+    WinRate = maps:get(win_rate, PerformanceAnalysis),
+
+    % adjust strategy parameters based on the win rate
+    AdjustedParams = adjust_strategy_parameters(CurrentSystem, WinRate),
+
+    % return the optimized decision system
+    CurrentSystem#{parameters => AdjustedParams}.
+
+%% Optimize the learning system
+optimize_learning_system(AIState, PerformanceAnalysis) ->
+    CurrentSystem = AIState#ai_state.learning_model,
+    LearningRate = maps:get(learning_rate, PerformanceAnalysis),
+
+    % adjust the learning rate
+    AdjustedLearningRate = adjust_learning_rate(LearningRate, PerformanceAnalysis),
+
+    % return the optimized learning system
+    CurrentSystem#{learning_rate => AdjustedLearningRate}.
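To make the grouping convention in ai_core concrete, group_cards_by_value/1 above maps each rank to its cards; on a toy input:

%% group_cards_by_value([{3,hearts},{3,spades},{5,clubs}])
%%   -> #{3 => [{3,spades},{3,hearts}], 5 => [{5,clubs}]}
%% (cards within a rank appear in reverse insertion order because the
%% fold prepends to each list)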
+
 optimize_ai_system(AIState, Metrics) ->
     % 分析性能指标
     PerformanceAnalysis = analyze_performance_metrics(Metrics),
@@ -13,8 +49,8 @@ optimize_ai_system(AIState, Metrics) ->
 
     % 更新AI状态
     AIState#ai_state{
-        decision_system = OptimizedDecisionSystem,
-        learning_system = OptimizedLearningSystem
+        strategy_model = OptimizedDecisionSystem,
+        learning_model = OptimizedLearningSystem
     }.
 
 tune_parameters(Parameters, Performance) ->
@@ -26,4 +62,54 @@ tune_parameters(Parameters, Performance) ->
         end,
         Parameters
     ),
-    maps:from_list(OptimizedParams).
\ No newline at end of file
+    maps:from_list(OptimizedParams).
+
+%% Internal helpers
+
+%% Adjust strategy parameters
+adjust_strategy_parameters(Strategy, WinRate) ->
+    CurrentParams = maps:get(parameters, Strategy, #{}),
+
+    % adjust each parameter based on the win rate
+    maps:map(
+        fun(ParamName, ParamValue) ->
+            adjust_parameter(ParamName, ParamValue, #{win_rate => WinRate})
+        end,
+        CurrentParams
+    ).
+
+%% Adjust the learning rate
+adjust_learning_rate(CurrentRate, PerformanceAnalysis) ->
+    Effectiveness = maps:get(strategy_effectiveness, PerformanceAnalysis, 0.5),
+
+    % effective strategy: decay the rate; ineffective strategy: grow it
+    case Effectiveness of
+        E when E > 0.7 -> max(0.001, CurrentRate * 0.9);
+        E when E < 0.3 -> min(0.1, CurrentRate * 1.1);
+        _ -> CurrentRate
+    end.
+
+%% Adjust a single parameter
+adjust_parameter(Param, Value, Performance) ->
+    WinRate = maps:get(win_rate, Performance, 0.5),
+
+    % simple adjustment heuristics
+    case Param of
+        risk_factor when WinRate < 0.4 ->
+            % low win rate: reduce risk
+            max(0.1, Value * 0.9);
+        risk_factor when WinRate > 0.6 ->
+            % high win rate: allow more risk
+            min(0.9, Value * 1.1);
+        aggressive_factor when WinRate < 0.4 ->
+            % low win rate: be less aggressive
+            max(0.1, Value * 0.9);
+        aggressive_factor when WinRate > 0.6 ->
+            % high win rate: be more aggressive
+            min(0.9, Value * 1.1);
+        _ ->
+            % leave other parameters unchanged
+            Value
+    end.
\ No newline at end of file
diff --git a/src/ai_player.erl b/src/ai_player.erl
index dfb952f..0774833 100644
--- a/src/ai_player.erl
+++ b/src/ai_player.erl
@@ -155,4 +155,50 @@ get_optimal_play(Cards, LastPlay) ->
             find_optimal_combination(Cards);
         _ ->
             find_minimum_bigger_combination(Cards, LastPlay)
-    end.
\ No newline at end of file
+    end.
+
+% Find the strongest combination
+find_biggest_combination(Cards) ->
+    % priority: rocket > bomb > straight > three-with-two > three-with-one > pair > single
+    case find_rocket(Cards) of
+        {ok, Rocket} -> Rocket;
+        _ ->
+            case find_bomb(Cards) of
+                {ok, Bomb} -> Bomb;
+                _ -> find_best_normal_combination(Cards)
+            end
+    end.
+
+% Find a combination that beats the previous play
+find_bigger_combination(Cards, LastPlay) ->
+    LastType = card_rules:get_card_type(LastPlay),
+    case LastType of
+        rocket -> [];  % a rocket cannot be beaten
+        bomb -> find_bigger_bomb(Cards, LastPlay);
+        _ -> find_bigger_normal_combination(Cards, LastPlay, LastType)
+    end.
+
+% Find the optimal combination
+find_optimal_combination(Cards) ->
+    % simplified: prefer small shapes (singles, pairs, three-with-one)
+    case find_single(Cards) of
+        {ok, Single} -> Single;
+        _ -> find_biggest_combination(Cards)
+    end.
+
+% Find the smallest combination that still beats the previous play
+find_minimum_bigger_combination(Cards, LastPlay) ->
+    % simplified: return any beating combination found
+    Bigger = find_bigger_combination(Cards, LastPlay),
+    case Bigger of
+        [] -> [];
+        _ -> Bigger
+    end.
+
+% Handle hot code upgrade
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+% Handle process termination
+terminate(_Reason, _State) ->
+    ok.
\ No newline at end of file
diff --git a/src/ai_strategy.erl b/src/ai_strategy.erl
index ae42d1c..fc5c50b 100644
--- a/src/ai_strategy.erl
+++ b/src/ai_strategy.erl
@@ -1,6 +1,6 @@
 -module(ai_strategy).
--export([initialize_strategy/0, update_strategy/2, make_decision/2,
-         analyze_game_state/1, evaluate_play/3]).
+-export([initialize_strategy/0, update_strategy/2, make_decision/2, analyze_game_state/1, evaluate_play/3,
+         generate_possible_plays/1, generate_singles/1, generate_pairs/1, generate_triples/1, generate_straights/1, generate_bombs/1]).
 
 -record(game_state, {
     hand_cards,        % 手牌
@@ -61,100 +61,162 @@ make_decision(Strategy, GameState) ->
     EvaluatedPlays = evaluate_all_plays(PossiblePlays, Strategy, GameState),
     select_best_play(EvaluatedPlays, Strategy, GameState).
 
-%% 评估所有可能的出牌
-evaluate_all_plays(Plays, Strategy, GameState) ->
-    lists:map(
-        fun(Play) ->
-            Score = evaluate_play(Play, Strategy, GameState),
-            {Play, Score}
-        end,
-        Plays
-    ).
+%% Concrete implementations of the play generators
+generate_singles(Cards) ->
+    lists:map(fun(C) -> [C] end, Cards).
 
-%% 评估单个出牌
-evaluate_play(Play, Strategy, GameState) ->
-    Weights = get_stage_weights(Strategy, GameState#game_state.stage),
-
-    ControlScore = evaluate_control(Play, GameState) * maps:get(control_weight, Weights),
-    ComboScore = evaluate_combo(Play, GameState) * maps:get(combo_weight, Weights),
-    DefensiveScore = evaluate_defensive(Play, GameState) * maps:get(defensive_weight, Weights),
-    RiskScore = evaluate_risk(Play, GameState) * maps:get(risk_weight, Weights),
-
-    ControlScore + ComboScore + DefensiveScore + RiskScore.
+generate_pairs(Cards) ->
+    UniqueCards = lists:usort(Cards),
+    [lists:duplicate(2, C) || C <- UniqueCards, length([C || C1 <- Cards, C1 == C]) >= 2].
 
-%% 选择最佳出牌
-select_best_play(EvaluatedPlays, Strategy, GameState) ->
-    case should_apply_randomness(Strategy, GameState) of
-        true ->
-            apply_randomness(EvaluatedPlays);
-        false ->
-            {Play, _Score} = lists:max(EvaluatedPlays),
-            Play
-    end.
+generate_triples(Cards) ->
+    UniqueCards = lists:usort(Cards),
+    [lists:duplicate(3, C) || C <- UniqueCards, length([C || C1 <- Cards, C1 == C]) >= 3].
 
-%% 内部辅助函数
+generate_straights(Cards) ->
+    Sorted = lists:usort(Cards),
+    find_straights(Sorted, 5).
 
-determine_game_stage(#game_state{remaining_cards = Remaining}) ->
-    case Remaining of
-        N when N > 15 -> early_game;
-        N when N > 8 -> mid_game;
+%% Return each run of Len consecutive ranks actually present in the hand.
+find_straights([], _Len) -> [];
+find_straights([H | T] = Cards, Len) ->
+    Seq = lists:seq(H, H + Len - 1),
+    case Seq -- Cards of
+        [] -> [Seq | find_straights(T, Len)];
+        _ -> find_straights(T, Len)
+    end.
+
+generate_bombs(Cards) ->
+    [lists:duplicate(4, C) || C <- lists:usort(Cards), length([C || C1 <- Cards, C1 == C]) >= 4].
+
+%% Generate all possible plays
+generate_possible_plays(Cards) ->
+    Singles = generate_singles(Cards),
+    Pairs = generate_pairs(Cards),
+    Triples = generate_triples(Cards),
+    Straights = generate_straights(Cards),
+    Bombs = generate_bombs(Cards),
+    Singles ++ Pairs ++ Triples ++ Straights ++ Bombs.
+
+%% Game stage detection
+determine_game_stage(GameState) ->
+    % the stage is derived from the remaining card count
+    case GameState#game_state.remaining_cards of
+        N when N > 20 -> early_game;
+        N when N > 10 -> mid_game;
         _ -> late_game
     end.
 
+%% Adjust weights
 adjust_weights(Strategy, Stage, GameState) ->
-    CurrentWeights = maps:get(Stage, maps:get(weights, Strategy)),
-    AdaptationRate = maps:get(adaptation_rate, Strategy),
-
-    % 基于游戏状态调整权重
-    adjust_weight_based_on_state(CurrentWeights, GameState, AdaptationRate).
+    CurrentWeights = maps:get(weights, Strategy),
+    StageWeights = maps:get(Stage, CurrentWeights),
+    % fine-tune the weights based on the game state
+    maps:map(
+        fun(Key, Value) ->
+            AdjustmentFactor = calculate_adjustment_factor(Key, GameState),
+            Value * AdjustmentFactor
+        end,
+        StageWeights
+    ).
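A quick shell illustration of the generators above on a toy hand of plain integer ranks (which is what this module assumes):

%% 1> ai_strategy:generate_pairs([3,3,4,5,5,5]).
%% [[3,3],[5,5]]
%% 2> ai_strategy:generate_triples([3,3,4,5,5,5]).
%% [[5,5,5]]
%% 3> ai_strategy:generate_bombs([3,3,4,5,5,5]).
%% []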
+%% Calculate an adjustment factor
+calculate_adjustment_factor(WeightKey, GameState) ->
+    % simple heuristic per weight type
+    case WeightKey of
+        control_weight when GameState#game_state.control_status -> 1.2;
+        control_weight -> 0.8;
+        _ -> 1.0
+    end.
+
+%% Update experience
 update_experience(Strategy, GameState) ->
     Experience = maps:get(experience, Strategy),
-    GamePattern = extract_game_pattern(GameState),
-
-    maps:update_with(
-        GamePattern,
-        fun(Count) -> Count + 1 end,
-        1,
-        Experience
-    ).
+    % simple implementation: count visits per game stage
+    Experience#{
+        GameState#game_state.stage => maps:get(GameState#game_state.stage, Experience, 0) + 1
+    }.
 
-evaluate_control(Play, GameState) ->
-    case Play of
-        pass -> 0.0;
-        _ ->
-            RemainingControl = calculate_remaining_control(GameState),
-            PlayStrength = game_logic:calculate_card_value(Play),
-            RemainingControl * PlayStrength / 100.0
-    end.
+%% Evaluate all plays
+evaluate_all_plays(Plays, Strategy, GameState) ->
+    [{Play, evaluate_play(Play, Strategy, GameState)} || Play <- Plays].
 
-evaluate_combo(Play, GameState) ->
-    RemainingCombos = count_remaining_combos(GameState#game_state.hand_cards -- Play),
-    case RemainingCombos of
-        0 -> 1.0;
-        _ -> 0.8 * (1 - 1/RemainingCombos)
-    end.
+%% Evaluate a single play
+evaluate_play(Play, Strategy, GameState) ->
+    ControlScore = evaluate_control_impact(Play, GameState),
+    ComboScore = evaluate_combo_potential(Play, GameState),
+    DefensiveScore = evaluate_defensive_value(Play, GameState),
+
+    % stage-specific weights
+    Stage = GameState#game_state.stage,
+    Weights = maps:get(Stage, maps:get(weights, Strategy)),
+
+    % weighted score
+    ControlWeight = maps:get(control_weight, Weights),
+    ComboWeight = maps:get(combo_weight, Weights),
+    DefensiveWeight = maps:get(defensive_weight, Weights),
+
+    (ControlScore * ControlWeight) + (ComboScore * ComboWeight) + (DefensiveScore * DefensiveWeight).
 
-evaluate_defensive(Play, GameState) ->
-    case GameState#game_state.player_position of
-        farmer ->
-            evaluate_farmer_defensive(Play, GameState);
-        landlord ->
-            evaluate_landlord_defensive(Play, GameState)
-    end.
+%% Control impact of a play
+evaluate_control_impact(Play, _GameState) ->
+    case {length(Play), lists:usort(Play)} of
+        {1, _} -> 0.3;      % singles offer little control
+        {2, _} -> 0.5;      % pairs are moderate
+        {3, _} -> 0.7;      % triples are stronger
+        {4, [_]} -> 1.0;    % four of a kind (a bomb) is strongest
+        _ -> 0.6            % other shapes
+    end.
 
-evaluate_risk(Play, GameState) ->
-    case is_risky_play(Play, GameState) of
-        true -> 0.3;
-        false -> 1.0
-    end.
+%% Combo potential
+evaluate_combo_potential(_Play, _GameState) ->
+    % simplified: medium potential by default
+    0.5.
+
+%% Defensive value
+evaluate_defensive_value(_Play, _GameState) ->
+    % simplified: medium defensive value by default
+    0.5.
+
+%% Select the best play
+select_best_play(EvaluatedPlays, _Strategy, _GameState) ->
+    case EvaluatedPlays of
+        [] -> pass;
+        _ ->
+            [{BestPlay, _Score} | _] =
+                lists:sort(fun({_, ScoreA}, {_, ScoreB}) -> ScoreA >= ScoreB end,
+                           EvaluatedPlays),
+            BestPlay
+    end.
 
-should_apply_randomness(Strategy, GameState) ->
-    ExperienceCount = maps:size(maps:get(experience, Strategy)),
-    ExperienceCount < 1000 orelse is_close_game(GameState).
+%% Analyze the game state
+analyze_game_state(GameState) ->
+    #{
+        hand_strength => calculate_hand_strength(GameState#game_state.hand_cards),
+        control_level => calculate_control_level(GameState),
+        stage => GameState#game_state.stage,
+        position_advantage => calculate_position_advantage(GameState)
+    }.
+
+%% Hand strength
+calculate_hand_strength(Cards) ->
+    % simple estimate from the number and quality of shapes
+    Bombs = length(generate_bombs(Cards)),
+    Triples = length(generate_triples(Cards)),
+    Pairs = length(generate_pairs(Cards)),
+
+    (Bombs * 0.5) + (Triples * 0.3) + (Pairs * 0.2).
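Running the same toy hand [3,3,4,5,5,5] through calculate_hand_strength/1: no bombs, one triple and two pairs give 0 * 0.5 + 1 * 0.3 + 2 * 0.2 = 0.7.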
+
+%% Control level
+calculate_control_level(GameState) ->
+    % simplified implementation
+    case GameState#game_state.control_status of
+        true -> 0.8;
+        false -> 0.2
+    end.
 
-apply_randomness(EvaluatedPlays) ->
-    RandomFactor = 0.1,
-    Plays = [{Play, Score + (rand:uniform() * RandomFactor)} || {Play, Score} <- EvaluatedPlays],
-    {SelectedPlay, _} = lists:max(Plays),
-    SelectedPlay.
\ No newline at end of file
+%% Positional advantage
+calculate_position_advantage(GameState) ->
+    % simplified: the landlord position carries more weight
+    case GameState#game_state.player_position of
+        landlord -> 0.7;
+        farmer -> 0.3
+    end.
\ No newline at end of file
diff --git a/src/ai_test.erl b/src/ai_test.erl
deleted file mode 100644
index 78d53e7..0000000
--- a/src/ai_test.erl
+++ /dev/null
@@ -1,58 +0,0 @@
--module(ai_test).
--export([run_test/0]).
-
-run_test() ->
-    % 启动所有必要的服务
-    {ok, DL} = deep_learning:start_link(),
-    {ok, PC} = parallel_compute:start_link(),
-    {ok, PM} = performance_monitor:start_link(),
-    {ok, VS} = visualization:start_link(),
-
-    % 创建测试数据
-    TestData = create_test_data(),
-
-    % 训练网络
-    {ok, Network} = deep_learning:train_network(test_network, TestData),
-
-    % 进行预测
-    TestInput = prepare_test_input(),
-    {ok, Prediction} = deep_learning:predict(test_network, TestInput),
-
-    % 监控性能
-    {ok, MonitorId} = performance_monitor:start_monitoring(test_network),
-
-    % 等待一些时间收集数据
-    timer:sleep(5000),
-
-    % 获取性能数据
-    {ok, Metrics} = performance_monitor:get_metrics(MonitorId),
-
-    % 创建可视化
-    {ok, ChartId} = visualization:create_chart(line_chart, Metrics),
-
-    % 导出结果
-    {ok, Report} = performance_monitor:generate_report(MonitorId),
-    {ok, Chart} = visualization:export_chart(ChartId, png),
-
-    % 清理资源
-    ok = performance_monitor:stop_monitoring(MonitorId),
-
-    % 返回测试结果
-    #{
-        prediction => Prediction,
-        metrics => Metrics,
-        report => Report,
-        chart => Chart
-    }.
-
-% 辅助函数
-create_test_data() ->
-    [
-        {[1,2,3], [4]},
-        {[2,3,4], [5]},
-        {[3,4,5], [6]},
-        {[4,5,6], [7]}
-    ].
-
-prepare_test_input() ->
-    [5,6,7].
\ No newline at end of file
diff --git a/src/auto_player.erl b/src/auto_player.erl
index 67ceb5d..59387ad 100644
--- a/src/auto_player.erl
+++ b/src/auto_player.erl
@@ -103,4 +103,12 @@ can_beat_play(Cards, LastPlay) ->
     case card_rules:find_valid_plays(Cards, LastPlay) of
         [] -> false;
         [BestPlay|_] -> {true, BestPlay}
-    end.
\ No newline at end of file
+    end.
+
+%% Handle hot code upgrade
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%% Handle process termination
+terminate(_Reason, _State) ->
+    ok.
\ No newline at end of file
diff --git a/src/card_checker.erl b/src/card_checker.erl
new file mode 100644
index 0000000..1df377f
--- /dev/null
+++ b/src/card_checker.erl
@@ -0,0 +1,297 @@
+-module(card_checker).
+-include("card_types.hrl").
+
+%% API exports
+-export([
+    check_card_type/1,    % determine the card type
+    compare_cards/2,      % compare two plays
+    is_valid_play/2,      % validate a play
+    format_cards/1,       % format cards for display
+    validate_cards/1      % validate card legality
+]).
+
+-type card() :: {integer(), atom()}.
+-type cards() :: [card()].
+-type card_type() :: integer().
+-type card_value() :: integer().
+
+%% API implementation
+
+%% @doc Determine the type of a group of cards
+-spec check_card_type(cards()) -> {ok, card_type(), card_value()} | {error, invalid_type}.
+check_card_type(Cards) when length(Cards) > 0 ->
+    SortedCards = sort_cards(Cards),
+    case identify_type(SortedCards) of
+        {ok, _Type, _Value} = Result -> Result;
+        Error -> Error
+    end;
+check_card_type([]) ->
+    {error, invalid_type}.
+
+%% @doc Compare two plays
+-spec compare_cards(cards(), cards()) -> greater | lesser | invalid.
+compare_cards(Cards1, Cards2) ->
+    case {check_card_type(Cards1), check_card_type(Cards2)} of
+        {{ok, Type1, Value1}, {ok, Type2, Value2}} ->
+            compare_types_and_values(Type1, Value1, Type2, Value2);
+        _ ->
+            invalid
+    end.
+
+%% @doc Validate a play against the previous play
+-spec is_valid_play(cards(), cards() | undefined) -> boolean().
+is_valid_play(_NewCards, undefined) ->
+    true;
+is_valid_play(NewCards, LastCards) ->
+    case {check_card_type(NewCards), check_card_type(LastCards)} of
+        {{ok, Type1, Value1}, {ok, Type2, Value2}} ->
+            can_beat(Type1, Value1, Type2, Value2);
+        _ ->
+            false
+    end.
+
+%% @doc Format cards for display
+-spec format_cards(cards()) -> [string()].
+format_cards(Cards) ->
+    lists:map(fun format_card/1, Cards).
+
+%% @doc Validate card legality
+-spec validate_cards(cards()) -> boolean().
+validate_cards(Cards) ->
+    is_valid_card_list(Cards) andalso
+    no_duplicate_cards(Cards) andalso
+    all_cards_valid(Cards).
+
+%% Internal functions
+
+%% @private Identify the card type
+identify_type(Cards) ->
+    case length(Cards) of
+        1 -> {ok, ?CARD_TYPE_SINGLE, get_card_value(hd(Cards))};
+        2 -> check_pair_or_rocket(Cards);
+        3 -> check_three(Cards);
+        4 -> check_four_or_three_one(Cards);
+        5 -> check_three_two(Cards);
+        _ -> check_sequence_types(Cards)
+    end.
+
+%% @private Check for a pair or rocket
+check_pair_or_rocket([{V1, _}, {V2, _}]) ->
+    if
+        V1 =:= V2 -> {ok, ?CARD_TYPE_PAIR, V1};
+        V1 =:= ?CARD_JOKER_SMALL, V2 =:= ?CARD_JOKER_BIG ->
+            % sort_cards/1 sorts ascending, so the small joker comes first
+            {ok, ?CARD_TYPE_ROCKET, ?CARD_JOKER_BIG};
+        true -> {error, invalid_type}
+    end.
+
+%% @private Check for a triple
+check_three([{V, _}, {V, _}, {V, _}]) ->
+    {ok, ?CARD_TYPE_THREE, V};
+check_three(_) ->
+    {error, invalid_type}.
+
+%% @private Check for a bomb or three-with-one
+check_four_or_three_one(Cards) ->
+    Values = [V || {V, _} <- Cards],
+    case count_values(Values) of
+        [{V, 4}] ->
+            {ok, ?CARD_TYPE_BOMB, V};
+        [{V, 3}, {_, 1}] ->
+            {ok, ?CARD_TYPE_THREE_ONE, V};
+        [{_, 1}, {V, 3}] ->
+            {ok, ?CARD_TYPE_THREE_ONE, V};
+        _ ->
+            {error, invalid_type}
+    end.
+
+%% @private Check for three-with-two
+check_three_two(Cards) ->
+    Values = [V || {V, _} <- Cards],
+    case count_values(Values) of
+        [{V, 3}, {_, 2}] -> {ok, ?CARD_TYPE_THREE_TWO, V};
+        [{_, 2}, {V, 3}] -> {ok, ?CARD_TYPE_THREE_TWO, V};
+        _ -> {error, invalid_type}
+    end.
+
+%% @private Check sequence-based types
+check_sequence_types(Cards) ->
+    case length(Cards) of
+        L when L >= 5 ->
+            Values = [V || {V, _} <- Cards],
+            cond_check_sequence_types(Values, Cards);
+        _ ->
+            {error, invalid_type}
+    end.
+
+%% @private Dispatch among the sequence types
+cond_check_sequence_types(Values, Cards) ->
+    case is_straight(Values) of
+        true ->
+            {ok, ?CARD_TYPE_STRAIGHT, lists:max(Values)};
+        false ->
+            case is_straight_pairs(Values) of
+                true ->
+                    {ok, ?CARD_TYPE_STRAIGHT_PAIR, lists:max(Values)};
+                false ->
+                    check_plane_types(Cards)
+            end
+    end.
+
+%% @private Check plane (airplane) types
+check_plane_types(Cards) ->
+    Values = [V || {V, _} <- Cards],
+    case check_plane_pattern(Values) of
+        {ok, MainValue, WithWings} ->
+            PlaneType = case WithWings of
+                true -> ?CARD_TYPE_PLANE_WINGS;
+                false -> ?CARD_TYPE_PLANE
+            end,
+            {ok, PlaneType, MainValue};
+        false ->
+            {error, invalid_type}
+    end.
+
+%% @private Is it a straight?
+is_straight(Values) ->
+    SortedVals = lists:sort(Values),
+    length(SortedVals) >= 5 andalso
+    lists:max(SortedVals) < ?CARD_2 andalso
+    is_consecutive(SortedVals).
+
+%% @private Is it consecutive pairs?
+is_straight_pairs(Values) ->
+    case count_values(Values) of
+        Pairs when length(Pairs) >= 3 ->
+            PairValues = [V || {V, 2} <- Pairs],
+            length(PairValues) * 2 =:= length(Values) andalso
+            lists:max(PairValues) < ?CARD_2 andalso
+            is_consecutive(lists:sort(PairValues));
+        _ ->
+            false
+    end.
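Illustrative calls against the API above (the concrete integers returned depend on the ?CARD_* macro values in card_types.hrl):

%% card_checker:check_card_type([{7,hearts},{7,spades}])
%%   -> {ok, ?CARD_TYPE_PAIR, 7}
%% card_checker:is_valid_play([{8,clubs},{8,spades}], [{7,hearts},{7,spades}])
%%   -> true   (same type, higher value)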
+
+%% @private Detect the plane pattern
+check_plane_pattern(Values) ->
+    Counts = count_values(Values),
+    ThreeCounts = [{V, C} || {V, C} <- Counts, C =:= 3],
+    case length(ThreeCounts) >= 2 of
+        true ->
+            ThreeValues = [V || {V, _} <- ThreeCounts],
+            SortedThrees = lists:sort(ThreeValues),
+            case is_consecutive(SortedThrees) of
+                true ->
+                    MainValue = lists:max(SortedThrees),
+                    HasWings = length(Values) > length(ThreeValues) * 3,
+                    {ok, MainValue, HasWings};
+                false ->
+                    false
+            end;
+        false ->
+            false
+    end.
+
+%% @private Are the values consecutive?
+is_consecutive([]) -> true;
+is_consecutive([_]) -> true;
+is_consecutive([A,B|Rest]) ->
+    case B - A of
+        1 -> is_consecutive([B|Rest]);
+        _ -> false
+    end.
+
+%% @private Count occurrences of each value
+count_values(Values) ->
+    Sorted = lists:sort(Values),
+    count_values(Sorted, [], 1).
+
+count_values([], Acc, _) ->
+    lists:reverse(Acc);
+count_values([V], Acc, Count) ->
+    lists:reverse([{V, Count}|Acc]);
+count_values([V,V|Rest], Acc, Count) ->
+    count_values([V|Rest], Acc, Count + 1);
+count_values([V1,V2|Rest], Acc, Count) ->
+    count_values([V2|Rest], [{V1, Count}|Acc], 1).
+
+%% @private Compare types and values
+compare_types_and_values(Type, Value, Type, Value2) ->
+    if
+        Value > Value2 -> greater;
+        true -> lesser
+    end;
+compare_types_and_values(?CARD_TYPE_ROCKET, _, _, _) ->
+    greater;
+compare_types_and_values(_, _, ?CARD_TYPE_ROCKET, _) ->
+    lesser;
+compare_types_and_values(?CARD_TYPE_BOMB, Value1, Type2, Value2) ->
+    case Type2 of
+        ?CARD_TYPE_BOMB when Value1 > Value2 -> greater;
+        ?CARD_TYPE_BOMB -> lesser;
+        _ -> greater
+    end;
+compare_types_and_values(_Type1, _, ?CARD_TYPE_BOMB, _) ->
+    lesser;
+compare_types_and_values(_, _, _, _) ->
+    invalid.
+
+%% @private Can the new play beat the last one?
+can_beat(Type1, Value1, Type2, Value2) ->
+    case {Type1, Type2} of
+        {Same, Same} -> Value1 > Value2;
+        {?CARD_TYPE_ROCKET, _} -> true;
+        {?CARD_TYPE_BOMB, OtherType} when OtherType =/= ?CARD_TYPE_ROCKET -> true;
+        _ -> false
+    end.
+
+%% @private Format a single card
+format_card({Value, Suit}) ->
+    SuitStr = case Suit of
+        hearts -> "♥";
+        diamonds -> "♦";
+        clubs -> "♣";
+        spades -> "♠";
+        joker -> "J"
+    end,
+    ValueStr = case Value of
+        ?CARD_JOKER_BIG -> "BJ";
+        ?CARD_JOKER_SMALL -> "SJ";
+        ?CARD_2 -> "2";
+        ?CARD_A -> "A";
+        ?CARD_K -> "K";
+        ?CARD_Q -> "Q";
+        ?CARD_J -> "J";
+        10 -> "10";
+        N when N >= 3, N =< 9 -> integer_to_list(N)
+    end,
+    SuitStr ++ ValueStr.
+
+%% @private Is it a valid card list?
+is_valid_card_list(Cards) ->
+    is_list(Cards) andalso length(Cards) > 0.
+
+%% @private No duplicate cards?
+no_duplicate_cards(Cards) ->
+    length(lists:usort(Cards)) =:= length(Cards).
+
+%% @private Are all cards legal?
+all_cards_valid(Cards) ->
+    lists:all(fun is_valid_card/1, Cards).
+
+%% @private Is a single card legal?
+is_valid_card({Value, Suit}) ->
+    (Value >= ?CARD_3 andalso Value =< ?CARD_2 andalso
+     lists:member(Suit, [hearts, diamonds, clubs, spades])) orelse
+    (Value >= ?CARD_JOKER_SMALL andalso Value =< ?CARD_JOKER_BIG andalso
+     Suit =:= joker).
+
+%% @private Sort cards ascending by value, then suit
+sort_cards(Cards) ->
+    lists:sort(fun({V1, S1}, {V2, S2}) ->
+        if
+            V1 =:= V2 -> S1 =< S2;
+            true -> V1 =< V2
+        end
+    end, Cards).
+
+%% @private Get a card's value
+get_card_value({Value, _}) -> Value.
\ No newline at end of file
diff --git a/src/card_rules.erl b/src/card_rules.erl
index 5148119..b87d986 100644
--- a/src/card_rules.erl
+++ b/src/card_rules.erl
@@ -47,6 +47,17 @@ check_sequence(Cards) ->
         false -> check_other_types(Cards)
     end.
+%% Check other card types
+check_other_types(Cards) ->
+    Grouped = group_cards(Cards),
+    case Grouped of
+        [{_, 3}, {_, 2}] -> three_two;   % three-with-two
+        [{_, 4}, {_, 1}] -> four_one;    % four-with-one
+        [{_, 4}, {_, 2}] -> four_two;    % four-with-two
+        [{_, 3}, {_, 3}] -> airplane;    % airplane
+        _ -> invalid
+    end.
+
 %% 比较牌大小
 compare_cards(Cards1, Cards2) ->
     case {get_card_type(Cards1), get_card_type(Cards2)} of
diff --git a/src/decision_engine.erl b/src/decision_engine.erl
deleted file mode 100644
index b29ff7f..0000000
--- a/src/decision_engine.erl
+++ /dev/null
@@ -1,53 +0,0 @@
--module(decision_engine).
--export([make_decision/3, evaluate_options/2, calculate_win_probability/2]).
-
-make_decision(GameState, AIState, Options) ->
-    % 深度评估每个选项
-    EvaluatedOptions = deep_evaluate_options(Options, GameState, AIState),
-
-    % 计算胜率
-    WinProbabilities = calculate_win_probabilities(EvaluatedOptions, GameState),
-
-    % 风险评估
-    RiskAnalysis = analyze_risks(EvaluatedOptions, GameState),
-
-    % 综合决策
-    BestOption = select_optimal_option(EvaluatedOptions, WinProbabilities, RiskAnalysis),
-
-    % 应用最终决策
-    apply_decision(BestOption, GameState, AIState).
-
-deep_evaluate_options(Options, GameState, AIState) ->
-    lists:map(
-        fun(Option) ->
-            % 深度搜索评估
-            SearchResult = monte_carlo_search(Option, GameState, 1000),
-
-            % 策略评估
-            StrategyScore = strategy_optimizer:evaluate_strategy(Option, GameState),
-
-            % 对手反应预测
-            OpponentReaction = opponent_modeling:predict_play(AIState#ai_state.opponent_model, GameState),
-
-            % 综合评分
-            {Option, calculate_comprehensive_score(SearchResult, StrategyScore, OpponentReaction)}
-        end,
-        Options
-    ).
-
-calculate_win_probability(Option, GameState) ->
-    % 基于当前局势分析
-    SituationScore = analyze_situation_score(GameState),
-
-    % 基于牌型分析
-    PatternScore = analyze_pattern_strength(Option),
-
-    % 基于对手模型
-    OpponentScore = analyze_opponent_factors(GameState),
-
-    % 计算综合胜率
-    calculate_combined_probability([
-        {SituationScore, 0.4},
-        {PatternScore, 0.3},
-        {OpponentScore, 0.3}
-    ]).
\ No newline at end of file
diff --git a/src/deep_learning.erl b/src/deep_learning.erl
deleted file mode 100644
index 00dceb8..0000000
--- a/src/deep_learning.erl
+++ /dev/null
@@ -1,104 +0,0 @@
--module(deep_learning).
--behaviour(gen_server).
-
--export([start_link/0, init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
--export([train_network/2, predict/2, update_network/2, get_network_stats/1]).
-
--record(state, {
-    networks = #{},        % Map: NetworkName -> NetworkData
-    training_queue = [],   % 训练队列
-    batch_size = 32,       % 批次大小
-    learning_rate = 0.001  % 学习率
-}).
-
--record(network, {
-    layers = [],           % 网络层结构
-    weights = #{},         % 权重
-    biases = #{},          % 偏置
-    activation = relu,     % 激活函数
-    optimizer = adam,      % 优化器
-    loss_history = [],     % 损失历史
-    accuracy_history = []  % 准确率历史
-}).
-
-%% API
-start_link() ->
-    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-
-train_network(NetworkName, TrainingData) ->
-    gen_server:call(?MODULE, {train, NetworkName, TrainingData}).
-
-predict(NetworkName, Input) ->
-    gen_server:call(?MODULE, {predict, NetworkName, Input}).
-
-update_network(NetworkName, Gradients) ->
-    gen_server:cast(?MODULE, {update, NetworkName, Gradients}).
-
-get_network_stats(NetworkName) ->
-    gen_server:call(?MODULE, {get_stats, NetworkName}).
-
-%% 内部函数
-
-initialize_network(LayerSizes) ->
-    Layers = create_layers(LayerSizes),
-    Weights = initialize_weights(Layers),
-    Biases = initialize_biases(Layers),
-    #network{
-        layers = Layers,
-        weights = Weights,
-        biases = Biases
-    }.
-
-create_layers(Sizes) ->
-    lists:map(fun(Size) ->
-        #{size => Size, type => dense}
-    end, Sizes).
-
-initialize_weights(Layers) ->
-    lists:foldl(
-        fun(Layer, Acc) ->
-            Size = maps:get(size, Layer),
-            W = random_matrix(Size, Size),
-            maps:put(Size, W, Acc)
-        end,
-        #{},
-        Layers
-    ).
-
-random_matrix(Rows, Cols) ->
-    matrix:new(Rows, Cols, fun() -> rand:uniform() - 0.5 end).
-
-forward_propagation(Input, Network) ->
-    #network{layers = Layers, weights = Weights, biases = Biases} = Network,
-    lists:foldl(
-        fun(Layer, Acc) ->
-            W = maps:get(maps:get(size, Layer), Weights),
-            B = maps:get(maps:get(size, Layer), Biases),
-            Z = matrix:add(matrix:multiply(W, Acc), B),
-            activate(Z, Network#network.activation)
-        end,
-        Input,
-        Layers
-    ).
-
-backward_propagation(Network, Input, Target) ->
-    % 实现反向传播算法
-    {Gradients, Loss} = calculate_gradients(Network, Input, Target),
-    {Network#network{
-        weights = update_weights(Network#network.weights, Gradients),
-        loss_history = [Loss | Network#network.loss_history]
-    }, Loss}.
-
-activate(Z, relu) ->
-    matrix:map(fun(X) -> max(0, X) end, Z);
-activate(Z, sigmoid) ->
-    matrix:map(fun(X) -> 1 / (1 + math:exp(-X)) end, Z);
-activate(Z, tanh) ->
-    matrix:map(fun(X) -> math:tanh(X) end, Z).
-
-optimize(Network, Gradients, adam) ->
-    % 实现Adam优化器
-    update_adam(Network, Gradients);
-optimize(Network, Gradients, sgd) ->
-    % 实现随机梯度下降
-    update_sgd(Network, Gradients).
\ No newline at end of file
diff --git a/src/doudizhu_ai.erl b/src/doudizhu_ai.erl
new file mode 100644
index 0000000..a18b909
--- /dev/null
+++ b/src/doudizhu_ai.erl
@@ -0,0 +1,277 @@
+-module(doudizhu_ai).
+-behaviour(gen_server).
+
+-export([
+    start_link/1,
+    make_decision/2,
+    analyze_cards/1,
+    update_game_state/2
+]).
+
+-export([
+    init/1,
+    handle_call/3,
+    handle_cast/2,
+    handle_info/2,
+    terminate/2,
+    code_change/3
+]).
+
+-include("card_types.hrl").
+
+-record(state, {
+    player_id,            % AI player id
+    role,                 % dizhu | nongmin (landlord or farmer)
+    known_cards = [],     % known cards
+    hand_cards = [],      % hand cards
+    played_cards = [],    % cards already played
+    other_players = [],   % other players' information
+    game_history = [],    % game history
+    strategy_cache = #{}  % strategy cache
+}).
+
+%% API functions
+start_link(PlayerId) ->
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [PlayerId], []).
+
+make_decision(GameState, Options) ->
+    gen_server:call(?MODULE, {make_decision, GameState, Options}).
+
+analyze_cards(Cards) ->
+    gen_server:call(?MODULE, {analyze_cards, Cards}).
+
+update_game_state(Event, Data) ->
+    gen_server:cast(?MODULE, {update_game_state, Event, Data}).
+
+%% Callback functions
+init([PlayerId]) ->
+    {ok, #state{player_id = PlayerId}}.
+
+handle_call({make_decision, GameState, Options}, _From, State) ->
+    {Decision, NewState} = calculate_best_move(GameState, Options, State),
+    {reply, Decision, NewState};
+
+handle_call({analyze_cards, Cards}, _From, State) ->
+    Analysis = perform_cards_analysis(Cards, State),
+    {reply, Analysis, State};
+
+handle_call(_Request, _From, State) ->
+    {reply, ok, State}.
+
+handle_cast({update_game_state, Event, Data}, State) ->
+    NewState = update_state(Event, Data, State),
+    {noreply, NewState};
+
+handle_cast(_Msg, State) ->
+    {noreply, State}.
+
+handle_info(_Info, State) ->
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%% Internal functions
+
+calculate_best_move(GameState, Options, State) ->
+    % 1. Analyze the current situation
+    Situation = analyze_situation(GameState, State),
+    % 2. Generate possible moves
+    PossibleMoves = generate_possible_moves(GameState, Options, State),
+    % 3. Evaluate each move
评估每个动作
+    ScoredMoves = evaluate_moves(PossibleMoves, Situation, State),
+    % 4. 选择最佳动作
+    {BestMove, Score} = select_best_move(ScoredMoves),
+    % 5. 更新状态
+    NewState = update_strategy_state(BestMove, Score, State),
+    {BestMove, NewState}.
+
+analyze_situation(GameState, State) ->
+    #situation{
+        game_stage = determine_game_stage(GameState),
+        hand_strength = evaluate_hand_strength(State#state.hand_cards),
+        control_level = calculate_control_level(GameState, State),
+        winning_probability = estimate_winning_probability(GameState, State)
+    }.
+
+%% 确定游戏阶段
+determine_game_stage(GameState) ->
+    CardsLeft = count_remaining_cards(GameState),
+    % Erlang has no cond/do; an if expression covers the three stages
+    if
+        CardsLeft > 15 -> early_game;
+        CardsLeft > 8 -> mid_game;
+        true -> end_game
+    end.
+
+%% 评估手牌强度
+evaluate_hand_strength(Cards) ->
+    % 分析手牌组成
+    Components = analyze_card_components(Cards),
+    % 计算基础分数
+    BaseScore = calculate_base_score(Components),
+    % 计算组合价值
+    ComboScore = calculate_combo_value(Components),
+    % 计算控制牌价值
+    ControlScore = calculate_control_value(Components),
+    % 返回综合评分
+    #hand_strength{
+        base_score = BaseScore,
+        combo_score = ComboScore,
+        control_score = ControlScore,
+        total_score = BaseScore + ComboScore + ControlScore
+    }.
+
+%% 分析牌型组件
+analyze_card_components(Cards) ->
+    % 对牌进行分组
+    Groups = group_cards_by_value(Cards),
+    % 识别各种牌型
+    Singles = find_singles(Groups),
+    Pairs = find_pairs(Groups),
+    Triples = find_triples(Groups),
+    Bombs = find_bombs(Groups),
+    Sequences = find_sequences(Groups),
+    % 返回组件分析结果
+    #components{
+        singles = Singles,
+        pairs = Pairs,
+        triples = Triples,
+        bombs = Bombs,
+        sequences = Sequences
+    }.
+
+%% 生成可能的出牌选择
+generate_possible_moves(GameState, _Options, State) ->
+    LastPlay = get_last_play(GameState),
+    HandCards = State#state.hand_cards,
+    case LastPlay of
+        none ->
+            generate_leading_moves(HandCards);
+        {Type, Value, _Cards} ->
+            generate_following_moves(HandCards, Type, Value)
+    end.
+
+%% 评估每个可能的动作
+evaluate_moves(Moves, Situation, State) ->
+    lists:map(fun(Move) ->
+        Score = calculate_move_score(Move, Situation, State),
+        {Move, Score}
+    end, Moves).
+
+%% 计算单个动作的得分
+calculate_move_score(Move, Situation, State) ->
+    BaseScore = calculate_base_move_score(Move),
+    PositionScore = calculate_position_score(Move, Situation),
+    StrategyScore = calculate_strategy_score(Move, Situation, State),
+    RiskScore = calculate_risk_score(Move, Situation, State),
+
+    BaseScore * 0.4 +
+    PositionScore * 0.2 +
+    StrategyScore * 0.3 +
+    RiskScore * 0.1.
+
+%% 选择最佳动作
+select_best_move(ScoredMoves) ->
+    lists:foldl(fun
+        ({Move, Score}, {_BestMove, BestScore}) when Score > BestScore ->
+            {Move, Score};
+        (_, Current) ->
+            Current
+    end, {pass, 0}, ScoredMoves).
+
+%% 更新策略状态
+update_strategy_state(Move, Score, State) ->
+    NewCache = update_strategy_cache(Move, Score, State#state.strategy_cache),
+    State#state{strategy_cache = NewCache}.
+
+%% 生成开局动作
+generate_leading_moves(Cards) ->
+    % 基于手牌生成所有可能的出牌选择
+    Components = analyze_card_components(Cards),
+    % 生成不同类型的出牌
+    Singles = generate_single_moves(Components),
+    Pairs = generate_pair_moves(Components),
+    Triples = generate_triple_moves(Components),
+    Sequences = generate_sequence_moves(Components),
+    Bombs = generate_bomb_moves(Components),
+    % 合并所有可能的出牌
+    Singles ++ Pairs ++ Triples ++ Sequences ++ Bombs.
+
+%% 生成跟牌动作
+generate_following_moves(Cards, Type, MinValue) ->
+    % 根据上家出牌类型生成可能的跟牌
+    ValidMoves = find_valid_moves(Cards, Type, MinValue),
+    % 添加炸弹和火箭选项
+    SpecialMoves = find_special_moves(Cards),
+    ValidMoves ++ SpecialMoves.
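+
+%% NOTE: find_valid_moves/3 and find_special_moves/1 are referenced above but
+%% are not part of this change. A minimal sketch of find_valid_moves/3 for the
+%% single and pair cases only, assuming cards are {Value, Suit} tuples and the
+%% type constants come from card_types.hrl:
+find_valid_moves(Cards, ?CARD_TYPE_SINGLE, MinValue) ->
+    [{?CARD_TYPE_SINGLE, V, [C]} || {V, _} = C <- Cards, V > MinValue];
+find_valid_moves(Cards, ?CARD_TYPE_PAIR, MinValue) ->
+    % Group by value, then keep any value with at least two cards
+    Grouped = lists:foldl(
+        fun({V, _} = C, Acc) ->
+            maps:update_with(V, fun(L) -> [C | L] end, [C], Acc)
+        end, #{}, Cards),
+    [{?CARD_TYPE_PAIR, V, lists:sublist(Cs, 2)}
+     || {V, Cs} <- maps:to_list(Grouped), length(Cs) >= 2, V > MinValue];
+find_valid_moves(_Cards, _Type, _MinValue) ->
+    [].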
+ +%% 估算胜率 +estimate_winning_probability(GameState, State) -> + % 考虑多个因素计算胜率 + HandStrength = evaluate_hand_strength(State#state.hand_cards), + Position = evaluate_position(GameState), + RemainingCards = analyze_remaining_cards(GameState, State), + ControlFactor = calculate_control_factor(GameState, State), + + BaseProb = calculate_base_probability(HandStrength, Position), + AdjustedProb = adjust_probability(BaseProb, RemainingCards, ControlFactor), + + clamp(AdjustedProb, 0.0, 1.0). + +%% 计算控制因子 +calculate_control_factor(GameState, State) -> + ControlCards = count_control_cards(State#state.hand_cards), + TotalCards = count_total_cards(GameState), + RemainingCards = count_remaining_cards(GameState), + + ControlRatio = ControlCards / max(1, RemainingCards), + PositionBonus = calculate_position_bonus(GameState, State), + + ControlRatio * PositionBonus. + +%% 分析剩余牌 +analyze_remaining_cards(GameState, State) -> + PlayedCards = get_played_cards(GameState), + KnownCards = State#state.known_cards, + AllCards = generate_full_deck(), + + RemainingCards = AllCards -- (PlayedCards ++ KnownCards), + analyze_card_distribution(RemainingCards). + +%% 更新游戏状态 +update_state(Event, Data, State) -> + case Event of + play_cards -> + update_after_play(Data, State); + receive_cards -> + update_after_receive(Data, State); + game_over -> + update_after_game_over(Data, State); + _ -> + State + end. + +%% 实用函数 + +clamp(Value, Min, Max) -> + min(Max, max(Min, Value)). + +%% 计算基础概率 +calculate_base_probability(HandStrength, Position) -> + BaseProb = HandStrength#hand_strength.total_score / 100, + PositionMod = case Position of + first -> 1.2; + middle -> 1.0; + last -> 0.8 + end, + BaseProb * PositionMod. + +%% 调整概率 +adjust_probability(BaseProb, RemainingCards, ControlFactor) -> + RemainingMod = calculate_remaining_modifier(RemainingCards), + ControlMod = calculate_control_modifier(ControlFactor), + + BaseProb * RemainingMod * ControlMod. \ No newline at end of file diff --git a/src/doudizhu_ai_strategy.erl b/src/doudizhu_ai_strategy.erl new file mode 100644 index 0000000..5255c2f --- /dev/null +++ b/src/doudizhu_ai_strategy.erl @@ -0,0 +1,399 @@ +-module(doudizhu_ai_strategy). + +-export([ + evaluate_situation/2, + choose_strategy/2, + execute_strategy/3, + calculate_move_score/3, + analyze_hand_value/1 +]). + +-include("card_types.hrl"). + +-record(strategy_context, { + game_stage, % early_game | mid_game | end_game + role, % dizhu | nongmin + hand_strength, % 手牌强度评估 + control_level, % 控制能力评估 + winning_prob, % 胜率估计 + cards_remaining, % 剩余手牌数 + opponent_cards % 对手剩余牌数 +}). + +-record(hand_value, { + singles = [], % 单牌 + pairs = [], % 对子 + triples = [], % 三张 + sequences = [], % 顺子 + bombs = [], % 炸弹 + rockets = [] % 火箭 +}). + +%% 评估局势 +evaluate_situation(GameState, PlayerState) -> + HandCards = PlayerState#state.hand_cards, + OpponentCards = get_opponent_cards(GameState), + + Context = #strategy_context{ + game_stage = determine_game_stage(GameState), + role = PlayerState#state.role, + hand_strength = evaluate_hand_strength(HandCards), + control_level = calculate_control_level(GameState, PlayerState), + winning_prob = estimate_winning_probability(GameState, PlayerState), + cards_remaining = length(HandCards), + opponent_cards = OpponentCards + }. + +%% 选择策略 +choose_strategy(Context, Options) -> + BaseStrategy = choose_base_strategy(Context), + RefineStrategy = refine_strategy(BaseStrategy, Context, Options), + adjust_strategy_for_endgame(RefineStrategy, Context). 
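+
+%% Illustrative wiring of the three entry points above (not part of the
+%% original change); the options list is left empty for brevity:
+%%
+%%   Context  = evaluate_situation(GameState, PlayerState),
+%%   Strategy = choose_strategy(Context, []),
+%%   Move     = execute_strategy(Strategy, GameState, PlayerState).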
+ +%% 执行策略 +execute_strategy(Strategy, GameState, PlayerState) -> + HandCards = PlayerState#state.hand_cards, + LastPlay = get_last_play(GameState), + + case Strategy of + aggressive -> + execute_aggressive_strategy(HandCards, LastPlay, GameState); + conservative -> + execute_conservative_strategy(HandCards, LastPlay, GameState); + control -> + execute_control_strategy(HandCards, LastPlay, GameState); + explosive -> + execute_explosive_strategy(HandCards, LastPlay, GameState) + end. + +%% 内部函数 + +%% 选择基础策略 +choose_base_strategy(Context) -> + case {Context#strategy_context.game_stage, Context#strategy_context.role} of + {early_game, dizhu} when Context#strategy_context.hand_strength >= 0.7 -> + aggressive; + {early_game, nongmin} when Context#strategy_context.hand_strength >= 0.8 -> + control; + {mid_game, _} when Context#strategy_context.winning_prob >= 0.7 -> + control; + {end_game, _} when Context#strategy_context.cards_remaining =< 5 -> + explosive; + _ -> + conservative + end. + +%% 细化策略 +refine_strategy(BaseStrategy, Context, Options) -> + case {BaseStrategy, Context#strategy_context.control_level} of + {conservative, Control} when Control >= 0.8 -> + control; + {aggressive, Control} when Control =< 0.3 -> + conservative; + {Strategy, _} -> + Strategy + end. + +%% 调整终盘策略 +adjust_strategy_for_endgame(Strategy, Context) -> + case Context#strategy_context.cards_remaining of + N when N =< 3 -> + case Context#strategy_context.winning_prob of + P when P >= 0.8 -> aggressive; + P when P =< 0.2 -> explosive; + _ -> Strategy + end; + _ -> + Strategy + end. + +%% 执行进攻策略 +execute_aggressive_strategy(HandCards, LastPlay, GameState) -> + HandValue = analyze_hand_value(HandCards), + case LastPlay of + none -> + choose_aggressive_lead(HandValue); + Play -> + choose_aggressive_follow(HandValue, Play) + end. + +%% 执行保守策略 +execute_conservative_strategy(HandCards, LastPlay, GameState) -> + HandValue = analyze_hand_value(HandCards), + case LastPlay of + none -> + choose_conservative_lead(HandValue); + Play -> + choose_conservative_follow(HandValue, Play) + end. + +%% 执行控制策略 +execute_control_strategy(HandCards, LastPlay, GameState) -> + HandValue = analyze_hand_value(HandCards), + case LastPlay of + none -> + choose_control_lead(HandValue); + Play -> + choose_control_follow(HandValue, Play) + end. + +%% 执行爆发策略 +execute_explosive_strategy(HandCards, LastPlay, GameState) -> + HandValue = analyze_hand_value(HandCards), + case LastPlay of + none -> + choose_explosive_lead(HandValue); + Play -> + choose_explosive_follow(HandValue, Play) + end. + +%% 选择进攻性首出 +choose_aggressive_lead(HandValue) -> + case find_best_lead_combination(HandValue) of + {ok, Move} -> Move; + none -> find_single_card(HandValue) + end. + +%% 选择进攻性跟牌 +choose_aggressive_follow(HandValue, LastPlay) -> + case find_minimum_greater_combination(HandValue, LastPlay) of + {ok, Move} -> Move; + none -> + case should_use_bomb(HandValue, LastPlay) of + true -> find_smallest_bomb(HandValue); + false -> pass + end + end. + +%% 选择保守性首出 +choose_conservative_lead(HandValue) -> + case find_safe_lead(HandValue) of + {ok, Move} -> Move; + none -> find_smallest_single(HandValue) + end. + +%% 选择保守性跟牌 +choose_conservative_follow(HandValue, LastPlay) -> + case find_safe_follow(HandValue, LastPlay) of + {ok, Move} -> Move; + none -> pass + end. + +%% 选择控制性首出 +choose_control_lead(HandValue) -> + case find_control_combination(HandValue) of + {ok, Move} -> Move; + none -> find_tempo_play(HandValue) + end. 
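+
+%% should_use_bomb/2 is called above but not defined in this change. A
+%% cautious sketch, assuming the #hand_value record from this module
+%% (LastPlay is ignored in this simplification):
+should_use_bomb(HandValue, _LastPlay) ->
+    % Bomb only when one is available and few loose singles remain to clean up
+    HandValue#hand_value.bombs =/= [] andalso
+        length(HandValue#hand_value.singles) =< 2.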
+ +%% 选择控制性跟牌 +choose_control_follow(HandValue, LastPlay) -> + case should_maintain_control(HandValue, LastPlay) of + true -> find_minimum_greater_combination(HandValue, LastPlay); + false -> pass + end. + +%% 选择爆发性首出 +choose_explosive_lead(HandValue) -> + case find_strongest_combination(HandValue) of + {ok, Move} -> Move; + none -> find_any_playable(HandValue) + end. + +%% 选择爆发性跟牌 +choose_explosive_follow(HandValue, LastPlay) -> + case find_crushing_play(HandValue, LastPlay) of + {ok, Move} -> Move; + none -> pass + end. + +%% 分析手牌价值 +analyze_hand_value(Cards) -> + GroupedCards = group_cards(Cards), + #hand_value{ + singles = find_singles(GroupedCards), + pairs = find_pairs(GroupedCards), + triples = find_triples(GroupedCards), + sequences = find_sequences(GroupedCards), + bombs = find_bombs(GroupedCards), + rockets = find_rockets(GroupedCards) + }. + +%% 计算出牌得分 +calculate_move_score(Move, Context, LastPlay) -> + BaseScore = calculate_base_score(Move), + TempoScore = calculate_tempo_score(Move, Context), + ControlScore = calculate_control_score(Move, Context), + EfficiencyScore = calculate_efficiency_score(Move, Context), + + WeightedScore = BaseScore * 0.3 + + TempoScore * 0.2 + + ControlScore * 0.3 + + EfficiencyScore * 0.2, + + adjust_score_for_context(WeightedScore, Move, Context, LastPlay). + +%% 辅助函数 + +group_cards(Cards) -> + lists:foldl(fun(Card, Acc) -> + {Value, _} = Card, + maps:update_with(Value, fun(List) -> [Card|List] end, [Card], Acc) + end, maps:new(), Cards). + +find_singles(GroupedCards) -> + maps:fold(fun(Value, [Card], Acc) -> + [{Value, Card}|Acc]; + (_, _, Acc) -> Acc + end, [], GroupedCards). + +find_pairs(GroupedCards) -> + maps:fold(fun(Value, Cards, Acc) -> + case length(Cards) >= 2 of + true -> [{Value, lists:sublist(Cards, 2)}|Acc]; + false -> Acc + end + end, [], GroupedCards). + +find_triples(GroupedCards) -> + maps:fold(fun(Value, Cards, Acc) -> + case length(Cards) >= 3 of + true -> [{Value, lists:sublist(Cards, 3)}|Acc]; + false -> Acc + end + end, [], GroupedCards). + +find_sequences(GroupedCards) -> + Values = lists:sort(maps:keys(GroupedCards)), + find_consecutive_sequences(Values, GroupedCards, 5). + +find_bombs(GroupedCards) -> + maps:fold(fun(Value, Cards, Acc) -> + case length(Cards) >= 4 of + true -> [{Value, Cards}|Acc]; + false -> Acc + end + end, [], GroupedCards). + +find_rockets(GroupedCards) -> + case {maps:get(?CARD_JOKER_SMALL, GroupedCards, []), + maps:get(?CARD_JOKER_BIG, GroupedCards, [])} of + {[Small], [Big]} -> [{?CARD_JOKER_BIG, [Small, Big]}]; + _ -> [] + end. + +find_consecutive_sequences(Values, GroupedCards, MinLength) -> + find_sequences_of_length(Values, GroupedCards, MinLength, []). + +find_sequences_of_length([], _, _, Acc) -> Acc; +find_sequences_of_length([V|Rest], GroupedCards, MinLength, Acc) -> + case find_sequence_starting_at(V, Rest, GroupedCards, MinLength) of + {ok, Sequence} -> + find_sequences_of_length(Rest, GroupedCards, MinLength, + [Sequence|Acc]); + false -> + find_sequences_of_length(Rest, GroupedCards, MinLength, Acc) + end. + +find_sequence_starting_at(Start, Rest, GroupedCards, MinLength) -> + case collect_sequence(Start, Rest, GroupedCards, []) of + Seq when length(Seq) >= MinLength -> + {ok, {Start, Seq}}; + _ -> + false + end. 
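+
+%% Worked example (illustrative): if the hand holds at least one card for each
+%% of the values [3,4,5,6,7], find_sequence_starting_at(3, [4,5,6,7], Grouped, 5)
+%% collects one card per value and returns {ok, {3, [C3,C4,C5,C6,C7]}}; a gap
+%% (say, no 6) stops collection early and the call returns false.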
+ +collect_sequence(Current, [Next|Rest], GroupedCards, Acc) when Next =:= Current + 1 -> + case maps:get(Current, GroupedCards) of + Cards when length(Cards) >= 1 -> + collect_sequence(Next, Rest, GroupedCards, [hd(Cards)|Acc]); + _ -> + lists:reverse(Acc) + end; +collect_sequence(Current, _, GroupedCards, Acc) -> + case maps:get(Current, GroupedCards) of + Cards when length(Cards) >= 1 -> + lists:reverse([hd(Cards)|Acc]); + _ -> + lists:reverse(Acc) + end. + +calculate_base_score({Type, Value, _Cards}) -> + BaseScore = Value * 10, + TypeMultiplier = case Type of + ?CARD_TYPE_ROCKET -> 100; + ?CARD_TYPE_BOMB -> 80; + ?CARD_TYPE_STRAIGHT -> 40; + ?CARD_TYPE_STRAIGHT_PAIR -> 35; + ?CARD_TYPE_PLANE -> 30; + ?CARD_TYPE_THREE_TWO -> 25; + ?CARD_TYPE_THREE_ONE -> 20; + ?CARD_TYPE_THREE -> 15; + ?CARD_TYPE_PAIR -> 10; + ?CARD_TYPE_SINGLE -> 5 + end, + BaseScore * TypeMultiplier / 100. + +calculate_tempo_score(Move, Context) -> + case Context#strategy_context.game_stage of + early_game -> calculate_early_tempo(Move, Context); + mid_game -> calculate_mid_tempo(Move, Context); + end_game -> calculate_end_tempo(Move, Context) + end. + +calculate_control_score(Move, Context) -> + BaseControl = case Move of + {Type, _, _} when Type =:= ?CARD_TYPE_BOMB; + Type =:= ?CARD_TYPE_ROCKET -> 1.0; + {_, Value, _} when Value >= ?CARD_2 -> 0.8; + _ -> 0.5 + end, + BaseControl * Context#strategy_context.control_level. + +calculate_efficiency_score(Move, Context) -> + {Type, _, Cards} = Move, + CardsUsed = length(Cards), + RemainingCards = Context#strategy_context.cards_remaining - CardsUsed, + Efficiency = CardsUsed / max(1, Context#strategy_context.cards_remaining), + Efficiency * (1 + (20 - RemainingCards) / 20). + +adjust_score_for_context(Score, Move, Context, LastPlay) -> + case {Context#strategy_context.role, LastPlay} of + {dizhu, none} -> Score * 1.2; + {dizhu, _} -> Score * 1.1; + {nongmin, _} -> Score + end. + +calculate_early_tempo(Move, Context) -> + case Move of + {Type, _, _} when Type =:= ?CARD_TYPE_SINGLE; + Type =:= ?CARD_TYPE_PAIR -> + 0.7; + {Type, _, _} when Type =:= ?CARD_TYPE_STRAIGHT; + Type =:= ?CARD_TYPE_STRAIGHT_PAIR -> + 0.9; + _ -> + 0.5 + end. + +calculate_mid_tempo(Move, Context) -> + case Move of + {Type, _, _} when Type =:= ?CARD_TYPE_THREE_ONE; + Type =:= ?CARD_TYPE_THREE_TWO -> + 0.8; + {Type, _, _} when Type =:= ?CARD_TYPE_PLANE -> + 0.9; + _ -> + 0.6 + end. + +calculate_end_tempo(Move, Context) -> + case Move of + {Type, _, _} when Type =:= ?CARD_TYPE_BOMB; + Type =:= ?CARD_TYPE_ROCKET -> + 1.0; + {_, Value, _} when Value >= ?CARD_2 -> + 0.9; + _ -> + 0.7 + end. \ No newline at end of file diff --git a/src/doudizhu_ai_sup.erl b/src/doudizhu_ai_sup.erl new file mode 100644 index 0000000..e8b1419 --- /dev/null +++ b/src/doudizhu_ai_sup.erl @@ -0,0 +1,44 @@ +-module(doudizhu_ai_sup). +-behaviour(supervisor). + +-export([start_link/0]). +-export([init/1]). + +start_link() -> + supervisor:start_link({local, ?MODULE}, ?MODULE, []). 
+
+init([]) ->
+    SupFlags = #{
+        strategy => one_for_one,
+        intensity => 10,
+        period => 60
+    },
+
+    Children = [
+        #{
+            id => ml_engine,
+            start => {ml_engine, start_link, []},
+            restart => permanent,
+            shutdown => 5000,
+            type => worker,
+            modules => [ml_engine]
+        },
+        #{
+            id => training_system,
+            start => {training_system, start_link, []},
+            restart => permanent,
+            shutdown => 5000,
+            type => worker,
+            modules => [training_system]
+        },
+        #{
+            id => visualization,
+            start => {visualization, start_link, []},
+            restart => permanent,
+            shutdown => 5000,
+            type => worker,
+            modules => [visualization]
+        }
+    ],
+
+    {ok, {SupFlags, Children}}.
\ No newline at end of file
diff --git a/src/game_core.erl b/src/game_core.erl
index bfe3e66..3d7ffab 100644
--- a/src/game_core.erl
+++ b/src/game_core.erl
@@ -100,37 +100,191 @@ analyze_complex_pattern(Cards) ->
         false -> analyze_other_patterns(Cards)
     end.
 
-%% 辅助函数
+%% 获取游戏状态
+get_game_state(_GameId) ->
+    % 简单实现,实际应该从某种存储中获取
+    {ok, #game_state{}}.
 
-card_value({_, "A"}) -> 14;
-card_value({_, "2"}) -> 15;
-card_value({_, "小王"}) -> 16;
-card_value({_, "大王"}) -> 17;
-card_value({_, N}) when is_list(N) ->
-    try list_to_integer(N)
-    catch _:_ ->
-        case N of
-            "J" -> 11;
-            "Q" -> 12;
-            "K" -> 13
-        end
+%% 判断出牌是否有效
+is_valid_play(Cards, LastPlay) ->
+    % 如果是第一手牌,任何合法牌型都可以
+    case LastPlay of
+        [] -> is_valid_card_type(Cards);
+        {_, LastCards} ->
+            % 检查牌型是否相同且大小是否更大;牌型不合法时直接判无效
+            case {get_card_type(Cards), get_card_type(LastCards)} of
+                {{ok, Type1, Value1}, {ok, Type2, Value2}} ->
+                    (Type1 =:= Type2 andalso Value1 > Value2) orelse
+                    (Type1 =:= bomb andalso Type2 =/= bomb);
+                _ -> false
+            end
     end.
 
-sort_cards(Cards) ->
-    lists:sort(fun(A, B) -> card_value(A) =< card_value(B) end, Cards).
+%% 判断牌型是否合法
+is_valid_card_type(Cards) ->
+    case get_card_type(Cards) of
+        {ok, _, _} -> true;
+        _ -> false
+    end.
+
+%% 比较两手牌的大小
+compare_cards(Cards1, Cards2) ->
+    {ok, Type1, Value1} = get_card_type(Cards1),
+    {ok, Type2, Value2} = get_card_type(Cards2),
+    if
+        Type1 =:= bomb andalso Type2 =/= bomb -> greater;
+        Type2 =:= bomb andalso Type1 =/= bomb -> lesser;
+        Type1 =:= Type2 andalso Value1 > Value2 -> greater;
+        Type1 =:= Type2 andalso Value1 < Value2 -> lesser;
+        true -> incomparable
+    end.
+
+%% 分配角色
+assign_roles(PlayerCards, LandlordIdx) ->
+    lists:map(
+        fun({Idx, {Pid, Cards}}) ->
+            Role = case Idx of
+                LandlordIdx -> landlord;
+                _ -> farmer
+            end,
+            {Pid, Cards, Role}
+        end,
+        lists:zip(lists:seq(1, length(PlayerCards)), PlayerCards)
+    ).
+
+%% 更新游戏状态
+update_game_state(GameState, PlayerPid, Cards) ->
+    % 更新玩家手牌
+    UpdatedPlayers = lists:map(
+        fun({Pid, PlayerCards, Role}) when Pid =:= PlayerPid ->
+            {Pid, PlayerCards -- Cards, Role};
+        (Player) -> Player
+        end,
+        GameState#game_state.players
+    ),
+
+    % 更新出牌记录和当前玩家
+    NextPlayer = get_next_player(GameState, PlayerPid),
+    GameState#game_state{
+        players = UpdatedPlayers,
+        current_player = NextPlayer,
+        last_play = {PlayerPid, Cards},
+        played_cards = [{PlayerPid, Cards} | GameState#game_state.played_cards]
+    }.
+
+%% 获取下一个玩家
+get_next_player(GameState, CurrentPid) ->
+    Players = GameState#game_state.players,
+    PlayerPids = [Pid || {Pid, _, _} <- Players],
+    CurrentIdx = get_player_index(PlayerPids, CurrentPid),
+    lists:nth(1 + (CurrentIdx rem length(PlayerPids)), PlayerPids).
+
+%% 获取玩家索引
+get_player_index(PlayerPids, TargetPid) ->
+    % 折叠时携带 {下一个计数, 已找到的位置}
+    {_, Found} = lists:foldl(
+        fun(Pid, {Idx, Acc}) ->
+            case {Pid =:= TargetPid, Acc} of
+                {true, not_found} -> {Idx + 1, Idx};
+                _ -> {Idx + 1, Acc}
+            end
+        end,
+        {1, not_found},
+        PlayerPids
+    ),
+    Found.
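+
+%% Example (illustrative): with players [P1, P2, P3], get_player_index/2
+%% returns 1..3, so get_next_player/2 cycles P1 -> P2 -> P3 -> P1 via
+%% lists:nth(1 + (Idx rem 3), PlayerPids).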
+
+%% 检查游戏是否结束
+check_game_end(GameState) ->
+    case lists:any(
+        fun({_, Cards, _}) -> length(Cards) =:= 0 end,
+        GameState#game_state.players
+    ) of
+        true ->
+            % 获胜者是手牌已经出完(为空列表)的玩家
+            Winner = lists:keyfind([], 2, GameState#game_state.players),
+            {game_over, Winner, GameState#game_state{stage = finished}};
+        false ->
+            {continue, GameState}
+    end.
+
+%% 获取玩家手牌
+get_player_cards(GameState, PlayerPid) ->
+    case lists:keyfind(PlayerPid, 1, GameState#game_state.players) of
+        {_, Cards, _} -> {ok, Cards};
+        false -> {error, player_not_found}
+    end.
+
+%% 检查玩家是否有这些牌
+has_cards(PlayerCards, CardsToPlay) ->
+    lists:all(
+        fun(Card) ->
+            Count = count_card(PlayerCards, Card),
+            CountToPlay = count_card(CardsToPlay, Card),
+            Count >= CountToPlay
+        end,
+        lists:usort(CardsToPlay)
+    ).
+
+%% 计算特定牌的数量
+count_card(Cards, TargetCard) ->
+    length([Card || Card <- Cards, Card =:= TargetCard]).
+
+%% 分析四带二
+analyze_four_with_two(Cards) ->
+    Grouped = group_cards(Cards),
+    % group_cards/1 返回 牌值 => 张数 的映射,这里找出出现四次的牌值
+    FourValues = [Number || {Number, 4} <- maps:to_list(Grouped)],
+    case {FourValues, length(Cards)} of
+        {[Value], 6} -> {four_with_two, card_value({any, Value})};
+        _ -> invalid
+    end.
+%% 分组牌
 group_cards(Cards) ->
     lists:foldl(
-        fun({_, N}, Acc) ->
-            maps:update_with(N, fun(L) -> [N|L] end, [N], Acc)
+        fun({_, Number}, Acc) ->
+            Count = maps:get(Number, Acc, 0) + 1,
+            Acc#{Number => Count}
         end,
         #{},
         Cards
     ).
+
+%% 获取最高牌值
+highest_card_value(Cards) ->
+    lists:max([card_value(Card) || Card <- Cards]).
+
+%% 牌值转换
+card_value({_, "3"}) -> 3;
+card_value({_, "4"}) -> 4;
+card_value({_, "5"}) -> 5;
+card_value({_, "6"}) -> 6;
+card_value({_, "7"}) -> 7;
+card_value({_, "8"}) -> 8;
+card_value({_, "9"}) -> 9;
+card_value({_, "10"}) -> 10;
+card_value({_, "J"}) -> 11;
+card_value({_, "Q"}) -> 12;
+card_value({_, "K"}) -> 13;
+card_value({_, "A"}) -> 14;
+card_value({_, "2"}) -> 15;
+card_value({_, "小王"}) -> 16;
+card_value({_, "大王"}) -> 17;
+card_value({any, Value}) -> card_value({"♠", Value}).
+
+%% 检查是否为顺子
 is_straight(Cards) ->
-    Values = [card_value(C) || C <- Cards],
-    Sorted = lists:sort(Values),
-    length(Sorted) >= 5 andalso
-    lists:all(fun({A, B}) -> B - A =:= 1 end,
-              lists:zip(Sorted, tl(Sorted))).
\ No newline at end of file
+    Values = [card_value(Card) || Card <- Cards],
+    SortedValues = lists:sort(Values),
+    length(SortedValues) >= 5 andalso
+    lists:max(SortedValues) =< 14 andalso  % A以下的牌才能组成顺子
+    SortedValues =:= lists:seq(hd(SortedValues), lists:last(SortedValues)).
+
+%% 分析其他牌型
+analyze_other_patterns(_Cards) ->
+    % 实现其他复杂牌型的分析,如飞机、连对等
+    % 简单实现,实际应该有更复杂的逻辑
+    invalid.
+
+%% 排序牌
+sort_cards(Cards) ->
+    lists:sort(fun(A, B) -> card_value(A) =< card_value(B) end, Cards).
\ No newline at end of file
diff --git a/src/game_manager.erl b/src/game_manager.erl
index 5dcd25a..130213e 100644
--- a/src/game_manager.erl
+++ b/src/game_manager.erl
@@ -48,4 +48,119 @@ handle_play(GameManagerState, Play) ->
         };
         false -> {error, invalid_play}
-    end.
\ No newline at end of file
+    end.
+
+%% 结束游戏
+end_game(GameManagerState) ->
+    % 计算最终得分
+    FinalScores = calculate_final_scores(GameManagerState),
+
+    % 更新玩家统计数据
+    update_player_stats(GameManagerState, FinalScores),
+
+    % 保存游戏记录
+    save_game_record(GameManagerState),
+
+    % 返回游戏结果
+    {game_ended, FinalScores}.
+
+%% 初始化AI玩家
+initialize_ai_players() ->
+    % 创建不同类型的AI玩家
+    BasicAI = ai_player:init(basic),
+    AdvancedAI = advanced_ai_player:init(advanced),
+
+    % 返回AI玩家列表
+    [BasicAI, AdvancedAI].
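+
+%% Note (illustrative): ai_player and advanced_ai_player are assumed to be
+%% separate modules exporting init/1; end_game/1 above then reports results
+%% as {game_ended, #{winner => ..., scores => ...}} via calculate_final_scores/1.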
+ +%% 生成游戏ID +generate_game_id() -> + % 使用时间戳和随机数生成唯一ID + {{Year, Month, Day}, {Hour, Minute, Second}} = calendar:local_time(), + Random = rand:uniform(1000), + list_to_binary(io_lib:format("~4..0B~2..0B~2..0B~2..0B~2..0B~2..0B~4..0B", + [Year, Month, Day, Hour, Minute, Second, Random])). + +%% 游戏循环 +game_loop(GameManagerState) -> + % 检查游戏是否结束 + case is_game_over(GameManagerState#game_manager_state.current_state) of + true -> + % 游戏结束,返回结果 + end_game(GameManagerState); + false -> + % 获取当前玩家 + CurrentPlayer = get_current_player(GameManagerState), + + % 如果是AI玩家,自动出牌 + case is_ai_player(CurrentPlayer, GameManagerState) of + true -> + AIPlay = get_ai_play(CurrentPlayer, GameManagerState), + NewState = handle_play(GameManagerState, AIPlay), + game_loop(NewState); + false -> + % 如果是人类玩家,返回当前状态等待输入 + {waiting_for_player, GameManagerState} + end + end. + +%% 更新AI玩家 +update_ai_players(AIPlayers, Play) -> + % 为每个AI玩家更新游戏信息 + lists:map( + fun(AIPlayer) -> + update_ai_knowledge(AIPlayer, Play) + end, + AIPlayers + ). + +%% 辅助函数 + +%% 计算最终得分 +calculate_final_scores(GameManagerState) -> + % 根据游戏规则计算得分 + % 简单实现,实际应该有更复杂的计分逻辑 + #{ + winner => get_winner(GameManagerState), + scores => get_player_scores(GameManagerState) + }. + +%% 更新玩家统计数据 +update_player_stats(GameManagerState, FinalScores) -> + % 更新玩家的胜率、得分等统计信息 + % 简单实现 + ok. + +%% 保存游戏记录 +save_game_record(GameManagerState) -> + % 将游戏记录保存到数据库或文件 + % 简单实现 + ok. + +%% 检查游戏是否结束 +is_game_over(GameState) -> + % 检查是否有玩家的牌已经出完 + % 简单实现 + false. + +%% 获取当前玩家 +get_current_player(GameManagerState) -> + % 从游戏状态中获取当前玩家 + GameState = GameManagerState#game_manager_state.current_state, + GameState#game_state.current_player. + +%% 判断是否为AI玩家 +is_ai_player(Player, GameManagerState) -> + % 检查玩家是否在AI玩家列表中 + lists:member(Player, GameManagerState#game_manager_state.ai_players). + +%% 获取AI出牌 +get_ai_play(AIPlayer, GameManagerState) -> + % 调用AI模块获取出牌决策 + GameState = GameManagerState#game_manager_state.current_state, + ai_core:make_decision(AIPlayer, GameState). + +%% 更新AI知识 +update_ai_knowledge(AIPlayer, Play) -> + % 更新AI对游戏的认知 + ai_core:update_knowledge(AIPlayer, Play). \ No newline at end of file diff --git a/src/matrix.erl b/src/matrix.erl deleted file mode 100644 index cb445e1..0000000 --- a/src/matrix.erl +++ /dev/null @@ -1,112 +0,0 @@ --module(matrix). --export([new/3, multiply/2, add/2, subtract/2, transpose/1, map/2]). --export([from_list/1, to_list/1, get/3, set/4, shape/1]). - --record(matrix, { - rows, - cols, - data -}). - -new(Rows, Cols, InitFun) when is_integer(Rows), is_integer(Cols), Rows > 0, Cols > 0 -> - Data = array:new(Rows * Cols, {default, 0.0}), - Data2 = case is_function(InitFun) of - true -> - lists:foldl( - fun(I, Acc) -> - lists:foldl( - fun(J, Acc2) -> - array:set(I * Cols + J, InitFun(), Acc2) - end, - Acc, - lists:seq(0, Cols-1) - ) - end, - Data, - lists:seq(0, Rows-1) - ); - false -> - array:set_value(InitFun, Data) - end, - #matrix{rows = Rows, cols = Cols, data = Data2}. - -multiply(#matrix{rows = M, cols = N, data = Data1}, - #matrix{rows = N, cols = P, data = Data2}) -> - Result = array:new(M * P, {default, 0.0}), - ResultData = lists:foldl( - fun(I, Acc1) -> - lists:foldl( - fun(J, Acc2) -> - Sum = lists:sum([ - array:get(I * N + K, Data1) * array:get(K * P + J, Data2) - || K <- lists:seq(0, N-1) - ]), - array:set(I * P + J, Sum, Acc2) - end, - Acc1, - lists:seq(0, P-1) - ) - end, - Result, - lists:seq(0, M-1) - ), - #matrix{rows = M, cols = P, data = ResultData}. 
- -add(#matrix{rows = R, cols = C, data = Data1}, - #matrix{rows = R, cols = C, data = Data2}) -> - NewData = array:map( - fun(I, V) -> V + array:get(I, Data2) end, - Data1 - ), - #matrix{rows = R, cols = C, data = NewData}. - -subtract(#matrix{rows = R, cols = C, data = Data1}, - #matrix{rows = R, cols = C, data = Data2}) -> - NewData = array:map( - fun(I, V) -> V - array:get(I, Data2) end, - Data1 - ), - #matrix{rows = R, cols = C, data = NewData}. - -transpose(#matrix{rows = R, cols = C, data = Data}) -> - NewData = array:new(R * C, {default, 0.0}), - TransposedData = lists:foldl( - fun(I, Acc1) -> - lists:foldl( - fun(J, Acc2) -> - array:set(J * R + I, array:get(I * C + J, Data), Acc2) - end, - Acc1, - lists:seq(0, C-1) - ) - end, - NewData, - lists:seq(0, R-1) - ), - #matrix{rows = C, cols = R, data = TransposedData}. - -map(Fun, #matrix{rows = R, cols = C, data = Data}) -> - NewData = array:map(fun(_, V) -> Fun(V) end, Data), - #matrix{rows = R, cols = C, data = NewData}. - -from_list(List) when is_list(List) -> - Rows = length(List), - Cols = length(hd(List)), - Data = array:from_list(lists:flatten(List)), - #matrix{rows = Rows, cols = Cols, data = Data}. - -to_list(#matrix{rows = R, cols = C, data = Data}) -> - [ - [array:get(I * C + J, Data) || J <- lists:seq(0, C-1)] - || I <- lists:seq(0, R-1) - ]. - -get(#matrix{cols = C, data = Data}, Row, Col) -> - array:get(Row * C + Col, Data). - -set(#matrix{cols = C, data = Data} = M, Row, Col, Value) -> - NewData = array:set(Row * C + Col, Value, Data), - M#matrix{data = NewData}. - -shape(#matrix{rows = R, cols = C}) -> - {R, C}. \ No newline at end of file diff --git a/src/ml_engine.erl b/src/ml_engine.erl index d716b19..1b9f0a6 100644 --- a/src/ml_engine.erl +++ b/src/ml_engine.erl @@ -1,157 +1,663 @@ -module(ml_engine). -behaviour(gen_server). --export([start_link/0, init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([train/2, predict/2, update_model/3, get_model_stats/1]). +%% API exports +-export([ + start_link/0, + train/2, + predict/2, + update_model/2, + get_model_state/0, + save_model/1, + load_model/1, + add_training_sample/1 +]). + +%% gen_server callbacks +-export([ + init/1, + handle_call/3, + handle_cast/2, + handle_info/2, + terminate/2, + code_change/3 +]). + +-include("card_types.hrl"). -record(state, { - models = #{}, % Map: ModelName -> ModelData - training_data = #{}, % Map: ModelName -> TrainingData - model_stats = #{}, % Map: ModelName -> Stats - last_update = undefined + model, % 当前模型 + model_version = 1, % 模型版本 + training_data = [], % 训练数据集 + hyperparameters = #{}, % 超参数 + feature_config = #{}, % 特征配置 + last_update, % 最后更新时间 + performance_metrics = [] % 性能指标历史 }). --record(model_data, { - weights = #{}, % 模型权重 - features = [], % 特征列表 - learning_rate = 0.01, % 学习率 - iterations = 0, % 训练迭代次数 - accuracy = 0.0 % 模型准确率 +-record(model, { + weights = #{}, % 模型权重 + biases = #{}, % 偏置项 + layers = [], % 网络层配置 + activation_functions = #{}, % 激活函数配置 + normalization_params = #{}, % 归一化参数 + feature_importance = #{}, % 特征重要性 + last_train_error = 0.0 % 最后一次训练误差 }). -%% API +%% API 函数 start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). -train(ModelName, TrainingData) -> - gen_server:call(?MODULE, {train, ModelName, TrainingData}). +train(Data, Options) -> + gen_server:call(?MODULE, {train, Data, Options}, infinity). + +predict(Features, Options) -> + gen_server:call(?MODULE, {predict, Features, Options}). 
-predict(ModelName, Features) -> - gen_server:call(?MODULE, {predict, ModelName, Features}). +update_model(NewModel, Options) -> + gen_server:cast(?MODULE, {update_model, NewModel, Options}). -update_model(ModelName, NewData, Reward) -> - gen_server:cast(?MODULE, {update_model, ModelName, NewData, Reward}). +get_model_state() -> + gen_server:call(?MODULE, get_model_state). -get_model_stats(ModelName) -> - gen_server:call(?MODULE, {get_stats, ModelName}). +save_model(Filename) -> + gen_server:call(?MODULE, {save_model, Filename}). -%% Callbacks +load_model(Filename) -> + gen_server:call(?MODULE, {load_model, Filename}). + +add_training_sample(Sample) -> + gen_server:cast(?MODULE, {add_training_sample, Sample}). + +%% Callback 函数 init([]) -> - % 初始化各种策略模型 - Models = initialize_models(), - {ok, #state{models = Models, last_update = os:timestamp()}}. - -handle_call({train, ModelName, TrainingData}, _From, State) -> - {NewModel, Stats} = train_model(ModelName, TrainingData, State), - NewModels = maps:put(ModelName, NewModel, State#state.models), - NewStats = maps:put(ModelName, Stats, State#state.model_stats), - {reply, {ok, Stats}, State#state{models = NewModels, model_stats = NewStats}}; - -handle_call({predict, ModelName, Features}, _From, State) -> - case maps:get(ModelName, State#state.models, undefined) of - undefined -> - {reply, {error, model_not_found}, State}; - Model -> - Prediction = make_prediction(Model, Features), - {reply, {ok, Prediction}, State} + {ok, #state{ + model = initialize_model(), + last_update = os:timestamp(), + hyperparameters = default_hyperparameters(), + feature_config = default_feature_config() + }}. + +handle_call({train, Data, Options}, _From, State) -> + {Result, NewState} = do_train(Data, Options, State), + {reply, Result, NewState}; + +handle_call({predict, Features, Options}, _From, State) -> + {Prediction, NewState} = do_predict(Features, Options, State), + {reply, Prediction, NewState}; + +handle_call(get_model_state, _From, State) -> + {reply, get_current_model_state(State), State}; + +handle_call({save_model, Filename}, _From, State) -> + Result = do_save_model(State, Filename), + {reply, Result, State}; + +handle_call({load_model, Filename}, _From, State) -> + case do_load_model(Filename) of + {ok, NewState} -> {reply, ok, NewState}; + Error -> {reply, Error, State} end; -handle_call({get_stats, ModelName}, _From, State) -> - Stats = maps:get(ModelName, State#state.model_stats, undefined), - {reply, {ok, Stats}, State}. +handle_call(_Request, _From, State) -> + {reply, {error, unknown_call}, State}. + +handle_cast({update_model, NewModel, Options}, State) -> + {noreply, do_update_model(State, NewModel, Options)}; + +handle_cast({add_training_sample, Sample}, State) -> + {noreply, do_add_training_sample(State, Sample)}; -handle_cast({update_model, ModelName, NewData, Reward}, State) -> - Model = maps:get(ModelName, State#state.models), - UpdatedModel = update_model_weights(Model, NewData, Reward), - NewModels = maps:put(ModelName, UpdatedModel, State#state.models), - {noreply, State#state{models = NewModels}}. +handle_cast(_Msg, State) -> + {noreply, State}. + +handle_info(update_metrics, State) -> + {noreply, update_performance_metrics(State)}; + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, State) -> + save_final_state(State), + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. 
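+
+%% Illustrative client session (not part of this change); the option keys
+%% follow default_hyperparameters/0 below, and Sample/Features stand in for
+%% caller-provided data:
+%%
+%%   {ok, _Pid} = ml_engine:start_link(),
+%%   ok = ml_engine:add_training_sample(Sample),
+%%   {ok, Metrics} = ml_engine:train(TrainingData, #{epochs => 10, batch_size => 32}),
+%%   Prediction = ml_engine:predict(Features, #{}).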
%% 内部函数 -initialize_models() -> - Models = #{ - play_strategy => init_play_strategy_model(), - card_combination => init_card_combination_model(), - opponent_prediction => init_opponent_prediction_model(), - game_state_evaluation => init_game_state_model() - }, - Models. - -init_play_strategy_model() -> - #model_data{ - features = [ - remaining_cards, - opponent_cards, - current_position, - game_stage, - last_play_type, - has_control - ] +%% 初始化模型 +initialize_model() -> + #model{ + weights = initialize_weights(), + biases = initialize_biases(), + layers = default_layer_configuration(), + activation_functions = default_activation_functions(), + normalization_params = initialize_normalization_params() + }. + +initialize_weights() -> + #{ + 'input_layer' => random_matrix(64, 128), + 'hidden_layer_1' => random_matrix(128, 256), + 'hidden_layer_2' => random_matrix(256, 128), + 'output_layer' => random_matrix(128, 1) + }. + +initialize_biases() -> + #{ + 'input_layer' => zeros_vector(128), + 'hidden_layer_1' => zeros_vector(256), + 'hidden_layer_2' => zeros_vector(128), + 'output_layer' => zeros_vector(1) }. -init_card_combination_model() -> - #model_data{ - features = [ - card_count, - card_types, - sequence_length, - combination_value - ] +%% 默认配置 +default_hyperparameters() -> + #{ + learning_rate => 0.001, + batch_size => 32, + epochs => 100, + momentum => 0.9, + dropout_rate => 0.5, + l2_regularization => 0.01, + early_stopping_patience => 5 }. -init_opponent_prediction_model() -> - #model_data{ - features = [ - played_cards, - remaining_unknown, - player_position, - playing_pattern - ] +default_feature_config() -> + #{ + card_value_weight => 1.0, + card_type_weight => 0.8, + sequence_weight => 1.2, + combo_weight => 1.5, + position_weight => 0.7, + timing_weight => 0.9 }. -init_game_state_model() -> - #model_data{ - features = [ - cards_played, - cards_remaining, - player_positions, - game_control - ] +default_layer_configuration() -> + [ + {input, 64}, + {dense, 128, relu}, + {dropout, 0.5}, + {dense, 256, relu}, + {dropout, 0.5}, + {dense, 128, relu}, + {dense, 1, sigmoid} + ]. + +default_activation_functions() -> + #{ + relu => fun(X) -> max(0, X) end, + sigmoid => fun(X) -> 1 / (1 + math:exp(-X)) end, + tanh => fun(X) -> math:tanh(X) end, + softmax => fun softmax/1 }. -train_model(ModelName, TrainingData, State) -> - Model = maps:get(ModelName, State#state.models), - {UpdatedModel, Stats} = case ModelName of - play_strategy -> - train_play_strategy(Model, TrainingData); - card_combination -> - train_card_combination(Model, TrainingData); - opponent_prediction -> - train_opponent_prediction(Model, TrainingData); - game_state_evaluation -> - train_game_state(Model, TrainingData) +%% 训练相关函数 +do_train(Data, Options, State) -> + try + % 数据预处理 + ProcessedData = preprocess_data(Data, State), + % 划分训练集和验证集 + {TrainData, ValidData} = split_train_valid(ProcessedData), + % 训练模型 + {NewModel, TrainMetrics} = train_model(TrainData, ValidData, Options, State), + % 更新状态 + NewState = update_state_after_training(State, NewModel, TrainMetrics), + {{ok, TrainMetrics}, NewState} + catch + Error:Reason -> + {{error, {Error, Reason}}, State} + end. 
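+
+%% split_train_valid/1 is used in do_train/3 above but is not included in
+%% this change. A minimal 80/20 split sketch (deterministic; a real
+%% implementation would shuffle first):
+split_train_valid(Data) ->
+    N = length(Data),
+    NTrain = min(N, max(1, (N * 8) div 10)),
+    lists:split(NTrain, Data).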
+
+%% 预测相关函数
+do_predict(Features, Options, State) ->
+    try
+        % 特征预处理
+        ProcessedFeatures = preprocess_features(Features, State),
+        % 执行预测 (forward_pass/2 takes the model first)
+        Prediction = forward_pass(State#state.model, ProcessedFeatures),
+        % 后处理预测结果
+        ProcessedPrediction = postprocess_prediction(Prediction, Options),
+        {ProcessedPrediction, State}
+    catch
+        Error:Reason ->
+            {{error, {Error, Reason}}, State}
+    end.
+
+%% 模型更新函数
+do_update_model(State, NewModel, _Options) ->
+    ValidatedModel = validate_model(NewModel),
+    State#state{
+        model = ValidatedModel,
+        model_version = State#state.model_version + 1,
+        last_update = os:timestamp()
+    }.
+
+%% 添加训练样本
+do_add_training_sample(State, Sample) ->
+    ValidatedSample = validate_sample(Sample),
+    NewTrainingData = [ValidatedSample | State#state.training_data],
+    State#state{training_data = NewTrainingData}.
+
+%% 数据预处理函数
+preprocess_data(Data, State) ->
+    % 特征提取
+    Features = extract_features(Data),
+    % 特征归一化 (nested record access needs the #model qualifier)
+    NormalizedFeatures = normalize_features(Features,
+        (State#state.model)#model.normalization_params),
+    % 特征选择
+    SelectedFeatures = select_features(NormalizedFeatures, State#state.feature_config),
+    % 增强特征
+    AugmentedFeatures = augment_features(SelectedFeatures),
+    AugmentedFeatures.
+
+%% 特征处理函数
+preprocess_features(Features, State) ->
+    % 特征归一化
+    NormalizedFeatures = normalize_features(Features,
+        (State#state.model)#model.normalization_params),
+    % 特征转换
+    TransformedFeatures = transform_features(NormalizedFeatures),
+    TransformedFeatures.
+
+%% 模型训练函数
+train_model(TrainData, ValidData, Options, State) ->
+    InitialModel = State#state.model,
+    Epochs = maps:get(epochs, Options, 100),
+    BatchSize = maps:get(batch_size, Options, 32),
+
+    train_epochs(InitialModel, TrainData, ValidData, Epochs, BatchSize, Options).
+
+train_epochs(Model, _, _, 0, _, _) ->
+    {Model, []};
+train_epochs(Model, TrainData, ValidData, Epochs, BatchSize, Options) ->
+    % 创建批次
+    Batches = create_batches(TrainData, BatchSize),
+    % 训练一个epoch
+    {UpdatedModel, EpochMetrics} = train_epoch(Model, Batches, ValidData, Options),
+    % 检查是否需要提前停止
+    case should_early_stop(EpochMetrics, Options) of
+        true ->
+            {UpdatedModel, EpochMetrics};
+        false ->
+            train_epochs(UpdatedModel, TrainData, ValidData, Epochs-1, BatchSize, Options)
+    end.
+
+train_epoch(Model, Batches, ValidData, Options) ->
+    % 训练所有批次
+    {TrainedModel, BatchMetrics} = train_batches(Model, Batches, Options),
+    % 在验证集上评估
+    ValidationMetrics = evaluate_model(TrainedModel, ValidData),
+    % 合并指标
+    EpochMetrics = merge_metrics(BatchMetrics, ValidationMetrics),
+    {TrainedModel, EpochMetrics}.
+
+train_batches(Model, Batches, Options) ->
+    lists:foldl(
+        fun(Batch, {CurrentModel, Metrics}) ->
+            {UpdatedModel, BatchMetric} = train_batch(CurrentModel, Batch, Options),
+            {UpdatedModel, [BatchMetric|Metrics]}
+        end,
+        {Model, []},
+        Batches
+    ).
+
+train_batch(Model, Batch, Options) ->
+    % 前向传播
+    {Predictions, CacheData} = forward_pass_with_cache(Model, Batch),
+    % 计算损失
+    {Loss, LossGrad} = calculate_loss(Predictions, Batch, Options),
+    % 反向传播
+    Gradients = backward_pass(LossGrad, CacheData, Model),
+    % 更新模型参数
+    UpdatedModel = update_model_parameters(Model, Gradients, Options),
+    % 返回更新后的模型和训练指标
+    {UpdatedModel, #{loss => Loss}}.
+
+%% 模型评估函数
+evaluate_model(Model, Data) ->
+    % 前向传播
+    Predictions = forward_pass(Model, Data),
+    % 计算各种评估指标
+    #{
+        accuracy => calculate_accuracy(Predictions, Data),
+        precision => calculate_precision(Predictions, Data),
+        recall => calculate_recall(Predictions, Data),
+        f1_score => calculate_f1_score(Predictions, Data)
+    }.
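+
+%% The metric helpers above (calculate_accuracy/2 etc.) are outside this
+%% change. A sketch of calculate_accuracy/2 for binary classification,
+%% assuming Data is a list of {Features, Label} pairs aligned with
+%% Predictions and labels are 0 or 1:
+calculate_accuracy(Predictions, Data) ->
+    Labels = [Label || {_, Label} <- Data],
+    Pairs = lists:zip(Predictions, Labels),
+    Correct = [ok || {P, L} <- Pairs, threshold_pred(P) =:= L],
+    length(Correct) / max(1, length(Pairs)).
+
+threshold_pred(P) when P >= 0.5 -> 1;
+threshold_pred(_) -> 0.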
+ +%% 工具函数 +random_matrix(Rows, Cols) -> + [ + [rand:normal() / math:sqrt(Rows) || _ <- lists:seq(1, Cols)] + || _ <- lists:seq(1, Rows) + ]. + +zeros_vector(Size) -> + [0.0 || _ <- lists:seq(1, Size)]. + +softmax(X) -> + Exp = [math:exp(Xi) || Xi <- X], + Sum = lists:sum(Exp), + [E / Sum || E <- Exp]. + +create_batches(Data, BatchSize) -> + create_batches(Data, BatchSize, []). + +create_batches([], _, Acc) -> + lists:reverse(Acc); +create_batches(Data, BatchSize, Acc) -> + {Batch, Rest} = case length(Data) of + N when N > BatchSize -> + lists:split(BatchSize, Data); + _ -> + {Data, []} end, - {UpdatedModel, Stats}. - -make_prediction(Model, Features) -> - % 使用模型权重和特征进行预测 - Weights = Model#model_data.weights, - calculate_prediction(Features, Weights). - -update_model_weights(Model, NewData, Reward) -> - % 使用强化学习更新模型权重 - CurrentWeights = Model#model_data.weights, - LearningRate = Model#model_data.learning_rate, - UpdatedWeights = apply_reinforcement_learning(CurrentWeights, NewData, Reward, LearningRate), - Model#model_data{weights = UpdatedWeights, iterations = Model#model_data.iterations + 1}. - -calculate_prediction(Features, Weights) -> - % 实现预测算法 + create_batches(Rest, BatchSize, [Batch|Acc]). + +%% 保存和加载模型 +do_save_model(State, Filename) -> + ModelData = #{ + model => State#state.model, + version => State#state.model_version, + hyperparameters => State#state.hyperparameters, + feature_config => State#state.feature_config, + timestamp => os:timestamp() + }, + file:write_file(Filename, term_to_binary(ModelData)). + +do_load_model(Filename) -> + case file:read_file(Filename) of + {ok, Binary} -> + try + ModelData = binary_to_term(Binary), + {ok, create_state_from_model_data(ModelData)} + catch + _:_ -> {error, invalid_model_file} + end; + Error -> + Error + end. + +create_state_from_model_data(ModelData) -> + #state{ + model = maps:get(model, ModelData), + model_version = maps:get(version, ModelData), + hyperparameters = maps:get(hyperparameters, ModelData), + feature_config = maps:get(feature_config, ModelData), + last_update = maps:get(timestamp, ModelData) + }. + +%% 性能指标更新 +update_performance_metrics(State) -> + NewMetrics = calculate_current_metrics(State), + State#state{ + performance_metrics = [NewMetrics | State#state.performance_metrics] + }. + +calculate_current_metrics(State) -> + Model = State#state.model, + #{ + loss => Model#model.last_train_error, + timestamp => os:timestamp() + }. + +%% 状态更新 +update_state_after_training(State, NewModel, Metrics) -> + State#state{ + model = NewModel, + model_version = State#state.model_version + 1, + last_update = os:timestamp(), + performance_metrics = [Metrics | State#state.performance_metrics] + }. + +%% 验证模型(续) +validate_model(Model) -> + % 验证权重和偏置 + ValidatedWeights = validate_weights(Model#model.weights), + ValidatedBiases = validate_biases(Model#model.biases), + % 验证网络层配置 + ValidatedLayers = validate_layers(Model#model.layers), + % 验证激活函数 + ValidatedActivations = validate_activation_functions(Model#model.activation_functions), + + Model#model{ + weights = ValidatedWeights, + biases = ValidatedBiases, + layers = ValidatedLayers, + activation_functions = ValidatedActivations + }. + +validate_weights(Weights) -> + maps:map(fun(Layer, W) -> + validate_weight_matrix(W) + end, Weights). + +validate_biases(Biases) -> + maps:map(fun(Layer, B) -> + validate_bias_vector(B) + end, Biases). + +validate_layers(Layers) -> + lists:map(fun validate_layer/1, Layers). 
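+
+%% validate_layer/1 and the weight/bias validators referenced above are not
+%% part of this change. Minimal structural checks as a sketch (dimension
+%% checks against the layer configuration are omitted):
+validate_layer({input, Size} = Layer) when is_integer(Size), Size > 0 -> Layer;
+validate_layer({dense, Size, _Act} = Layer) when is_integer(Size), Size > 0 -> Layer;
+validate_layer({dropout, Rate} = Layer) when Rate >= 0.0, Rate < 1.0 -> Layer;
+validate_layer(Layer) -> error({invalid_layer, Layer}).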
+
+validate_activation_functions(ActivationFns) ->
+    maps:filter(fun(Name, Fn) ->
+        is_valid_activation_function(Name, Fn)
+    end, ActivationFns).
+
+%% 前向传播相关函数
+forward_pass(Model, Input) ->
+    {Output, _Cache} = forward_pass_with_cache(Model, Input),
+    Output.
+
+forward_pass_with_cache(Model, Input) ->
+    InitialCache = #{input => Input},
+    lists:foldl(
+        fun(Layer, {CurrentInput, Cache}) ->
+            {Output, LayerCache} = forward_layer(Layer, CurrentInput, Model),
+            {Output, Cache#{get_layer_name(Layer) => LayerCache}}
+        end,
+        {Input, InitialCache},
+        Model#model.layers
+    ).
+
+% 输入层不做变换,直接透传
+forward_layer({input, _Size}, Input, _Model) ->
+    {Input, #{}};
+
+forward_layer({dense, _Size, Activation}, Input, Model) ->
+    Weights = maps:get(dense, Model#model.weights),
+    Bias = maps:get(dense, Model#model.biases),
+    % 线性变换 ('+' cannot add lists, so add the bias vector row by row)
+    Z0 = matrix_multiply(Input, Weights),
+    Z = [[Zij + Bj || {Zij, Bj} <- lists:zip(Row, Bias)] || Row <- Z0],
+    % 激活函数 (the configured funs are scalar, so apply element-wise)
+    ActivationFn = maps:get(Activation, Model#model.activation_functions),
+    Output = [[ActivationFn(Zij) || Zij <- Row] || Row <- Z],
+    % 缓存输入供反向传播计算权重梯度使用
+    {Output, #{input => Input, pre_activation => Z, output => Output}};
+
+forward_layer({dropout, Rate}, Input, Model) ->
+    case get_training_mode(Model) of
+        true ->
+            Mask = generate_dropout_mask(Input, Rate),
+            Output = element_wise_multiply(Input, Mask),
+            {Output, #{mask => Mask}};
+        false ->
+            {Input, #{}}
+    end.
+
+%% 反向传播相关函数
+backward_pass(LossGrad, Cache, Model) ->
+    {_, Gradients} = lists:foldr(
+        fun(Layer, {CurrentGrad, LayerGrads}) ->
+            LayerCache = maps:get(get_layer_name(Layer), Cache),
+            {NextGrad, LayerGrad} = backward_layer(Layer, CurrentGrad, LayerCache, Model),
+            {NextGrad, [LayerGrad | LayerGrads]}
+        end,
+        {LossGrad, []},
+        Model#model.layers
+    ),
+    consolidate_gradients(Gradients).
+
+% 输入层没有参数,梯度原样向前传递
+backward_layer({input, _Size}, Grad, _Cache, _Model) ->
+    {Grad, #{}};
+
+backward_layer({dense, _Size, Activation}, Grad, Cache, Model) ->
+    % 获取激活函数导数
+    ActivationGrad = get_activation_gradient(Activation),
+    % 计算激活函数的梯度
+    PreAct = maps:get(pre_activation, Cache),
+    DZ = element_wise_multiply(Grad, ActivationGrad(PreAct)),
+    % 计算权重和偏置的梯度
+    Input = maps:get(input, Cache),
+    WeightGrad = matrix_multiply(transpose(Input), DZ),
+    BiasGrad = sum_columns(DZ),
+    % 计算输入的梯度
+    Weights = maps:get(dense, Model#model.weights),
+    InputGrad = matrix_multiply(DZ, transpose(Weights)),
+    {InputGrad, #{weights => WeightGrad, bias => BiasGrad}};
+
+backward_layer({dropout, _Rate}, Grad, Cache, _Model) ->
+    % 推理模式下前向没有存 mask,此时梯度原样透传
+    case maps:get(mask, Cache, undefined) of
+        undefined -> {Grad, #{}};
+        Mask -> {element_wise_multiply(Grad, Mask), #{}}
+    end.
+
+%% 损失函数相关
+calculate_loss(Predictions, Targets, Options) ->
+    LossType = maps:get(loss_type, Options, cross_entropy),
+    calculate_loss_by_type(LossType, Predictions, Targets).
+
+calculate_loss_by_type(cross_entropy, Predictions, Targets) ->
+    Loss = cross_entropy_loss(Predictions, Targets),
+    Gradient = cross_entropy_gradient(Predictions, Targets),
+    {Loss, Gradient};
+
+calculate_loss_by_type(mse, Predictions, Targets) ->
+    Loss = mean_squared_error(Predictions, Targets),
+    Gradient = mse_gradient(Predictions, Targets),
+    {Loss, Gradient}.
+
+%% 优化器相关函数
+update_model_parameters(Model, Gradients, Options) ->
+    Optimizer = maps:get(optimizer, Options, adam),
+    LearningRate = maps:get(learning_rate, Options, 0.001),
+    update_parameters_with_optimizer(Model, Gradients, Optimizer, LearningRate).
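+
+%% The loss helpers named in calculate_loss_by_type/3 above are not included
+%% in this change. Element-wise sketches for a vector of probabilities and
+%% binary targets, with clipping to keep log/1 finite:
+cross_entropy_loss(Predictions, Targets) ->
+    Eps = 1.0e-12,
+    N = max(1, length(Predictions)),
+    Terms = [T * math:log(max(Eps, P)) + (1 - T) * math:log(max(Eps, 1 - P))
+             || {P, T} <- lists:zip(Predictions, Targets)],
+    -lists:sum(Terms) / N.
+
+cross_entropy_gradient(Predictions, Targets) ->
+    % Per-element gradient of the loss w.r.t. the (sigmoid) outputs
+    [P - T || {P, T} <- lists:zip(Predictions, Targets)].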
+
+update_parameters_with_optimizer(Model, Gradients, adam, LearningRate) ->
+    % Adam优化器实现
+    Beta1 = 0.9,
+    Beta2 = 0.999,
+    Epsilon = 1.0e-8,
+
+    % Simplified sketch: #model carries no optimizer-state field, so the
+    % moment estimates start empty on each call; a full implementation
+    % would persist them in the model record.
+    {NewWeights, _NewMomentum} = update_adam_parameters(
+        Model#model.weights,
+        maps:get(weights, Gradients),
+        #{},
+        LearningRate,
+        Beta1,
+        Beta2,
+        Epsilon
+    ),
+
+    Model#model{weights = NewWeights};
+
+update_parameters_with_optimizer(Model, Gradients, sgd, LearningRate) ->
+    % SGD优化器实现
+    NewWeights = update_sgd_parameters(
+        Model#model.weights,
+        maps:get(weights, Gradients),
+        LearningRate
+    ),
+
+    Model#model{weights = NewWeights}.
+
+%% 特征工程相关函数
+extract_features(Data) ->
+    lists:map(fun extract_sample_features/1, Data).
+
+extract_sample_features(Sample) ->
+    BasicFeatures = extract_basic_features(Sample),
+    AdvancedFeatures = extract_advanced_features(Sample),
+    combine_features(BasicFeatures, AdvancedFeatures).
+
+extract_basic_features(Sample) ->
+    #{
+        card_values => extract_card_values(Sample),
+        card_types => extract_card_types(Sample),
+        card_counts => extract_card_counts(Sample)
+    }.
+
+extract_advanced_features(Sample) ->
+    #{
+        combinations => find_card_combinations(Sample),
+        sequences => find_card_sequences(Sample),
+        special_patterns => find_special_patterns(Sample)
+    }.
+
+%% 矩阵操作辅助函数
+matrix_multiply(A, B) ->
+    % 矩阵乘法实现
+    case {matrix_dimensions(A), matrix_dimensions(B)} of
+        {{RowsA, ColsA}, {RowsB, ColsB}} when ColsA =:= RowsB ->
+            do_matrix_multiply(A, B, RowsA, ColsB);
+        _ ->
+            error(matrix_dimension_mismatch)
+    end.
+
+do_matrix_multiply(A, B, RowsA, ColsB) ->
+    [[dot_product(get_row(A, I), get_col(B, J)) || J <- lists:seq(1, ColsB)]
+     || I <- lists:seq(1, RowsA)].
+
+dot_product(Vec1, Vec2) ->
+    lists:sum([X * Y || {X, Y} <- lists:zip(Vec1, Vec2)]).
+
+transpose(Matrix) ->
+    case Matrix of
+        [] -> [];
+        [[]|_] -> [];
+        _ ->
+            [get_col(Matrix, I) || I <- lists:seq(1, length(hd(Matrix)))]
+    end.
+
+%% 工具函数
+get_layer_name({Type, Size, _}) ->
+    atom_to_list(Type) ++ "_" ++ integer_to_list(Size);
+get_layer_name({Type, Param}) when is_integer(Param) ->
+    % e.g. {input, 64}
+    atom_to_list(Type) ++ "_" ++ integer_to_list(Param);
+get_layer_name({Type, Param}) ->
+    % e.g. {dropout, 0.5}
+    atom_to_list(Type) ++ "_" ++ float_to_list(Param).
+
+generate_dropout_mask(Input, Rate) ->
+    % matrix_dimensions/1 returns {Rows, Cols}; build a 0/1 mask of the
+    % same shape, dropping each element with probability Rate
+    {Rows, Cols} = matrix_dimensions(Input),
+    [[case rand:uniform() < Rate of true -> 0.0; false -> 1.0 end
+      || _ <- lists:seq(1, Cols)]
+     || _ <- lists:seq(1, Rows)].
+
+element_wise_multiply(A, B) ->
+    [[X * Y || {X, Y} <- lists:zip(RowA, RowB)]
+     || {RowA, RowB} <- lists:zip(A, B)].
+
+sum_columns(Matrix) ->
+    lists:foldl(
+        fun(Row, Acc) ->
+            [X + Y || {X, Y} <- lists:zip(Row, Acc)]
+        end,
+        lists:duplicate(length(hd(Matrix)), 0.0),
+        Matrix
+    ).
+
+matrix_dimensions([]) -> {0, 0};
+matrix_dimensions([[]|_]) -> {0, 0};
+matrix_dimensions(Matrix) ->
+    {length(Matrix), length(hd(Matrix))}.
+
+get_row(Matrix, I) ->
+    lists:nth(I, Matrix).
+
+get_col(Matrix, J) ->
+    [lists:nth(J, Row) || Row <- Matrix].
+
+%% 初始化和保存最终状态
+save_final_state(State) ->
+    Filename = "ml_model_" ++ format_timestamp() ++ ".state",
+    do_save_model(State, Filename).
+
+format_timestamp() ->
+    {{Year, Month, Day}, {Hour, Minute, Second}} = calendar:universal_time(),
+    lists:flatten(io_lib:format("~4..0w~2..0w~2..0w_~2..0w~2..0w~2..0w",
+                                [Year, Month, Day, Hour, Minute, Second])).
\ No newline at end of file
diff --git a/src/opponent_modeling.erl b/src/opponent_modeling.erl
deleted file mode 100644
index 88454a9..0000000
--- a/src/opponent_modeling.erl
+++ /dev/null
@@ -1,62 +0,0 @@
--module(opponent_modeling).
--export([create_model/0, update_model/2, analyze_opponent/2, predict_play/2]). - --record(opponent_model, { - play_patterns = #{}, % 出牌模式统计 - card_preferences = #{}, % 牌型偏好 - risk_profile = 0.5, % 风险偏好 - skill_rating = 500, % 技能评分 - play_history = [] % 历史出牌记录 -}). - -create_model() -> - #opponent_model{}. - -update_model(Model, GamePlay) -> - % 更新出牌模式统计 - NewPatterns = update_play_patterns(Model#opponent_model.play_patterns, GamePlay), - - % 更新牌型偏好 - NewPreferences = update_card_preferences(Model#opponent_model.card_preferences, GamePlay), - - % 更新风险偏好 - NewRiskProfile = calculate_risk_profile(Model#opponent_model.risk_profile, GamePlay), - - % 更新技能评分 - NewSkillRating = update_skill_rating(Model#opponent_model.skill_rating, GamePlay), - - % 更新历史记录 - NewHistory = [GamePlay | Model#opponent_model.play_history], - - Model#opponent_model{ - play_patterns = NewPatterns, - card_preferences = NewPreferences, - risk_profile = NewRiskProfile, - skill_rating = NewSkillRating, - play_history = lists:sublist(NewHistory, 100) % 保留最近100次出牌记录 - }. - -analyze_opponent(Model, GameState) -> - #{ - style => determine_play_style(Model), - strength => calculate_opponent_strength(Model), - predictability => calculate_predictability(Model), - weakness => identify_weaknesses(Model) - }. - -predict_play(Model, GameState) -> - % 基于历史模式预测 - HistoryBasedPrediction = predict_from_history(Model, GameState), - - % 基于牌型偏好预测 - PreferenceBasedPrediction = predict_from_preferences(Model, GameState), - - % 基于风险偏好预测 - RiskBasedPrediction = predict_from_risk_profile(Model, GameState), - - % 综合预测结果 - combine_predictions([ - {HistoryBasedPrediction, 0.4}, - {PreferenceBasedPrediction, 0.3}, - {RiskBasedPrediction, 0.3} - ]). \ No newline at end of file diff --git a/src/parallel_compute.erl b/src/parallel_compute.erl deleted file mode 100644 index 50ce221..0000000 --- a/src/parallel_compute.erl +++ /dev/null @@ -1,55 +0,0 @@ --module(parallel_compute). --behaviour(gen_server). - --export([start_link/0, init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([parallel_predict/2, batch_process/2]). - --record(state, { - worker_pool = [], % 工作进程池 - job_queue = [], % 任务队列 - results = #{}, % 结果集 - pool_size = 4 % 默认工作进程数 -}). - -%% API -start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). - -parallel_predict(Inputs, Model) -> - gen_server:call(?MODULE, {parallel_predict, Inputs, Model}). - -batch_process(BatchData, ProcessFun) -> - gen_server:call(?MODULE, {batch_process, BatchData, ProcessFun}). - -%% 内部函数 - -initialize_worker_pool(PoolSize) -> - [spawn_worker() || _ <- lists:seq(1, PoolSize)]. - -spawn_worker() -> - spawn_link(fun() -> worker_loop() end). - -worker_loop() -> - receive - {process, Data, From} -> - Result = process_data(Data), - From ! {result, self(), Result}, - worker_loop(); - stop -> - ok - end. - -process_data({predict, Input, Model}) -> - deep_learning:predict(Model, Input); -process_data({custom, Fun, Data}) -> - Fun(Data). - -distribute_work(Workers, Jobs) -> - distribute_work(Workers, Jobs, #{}). - -distribute_work(_, [], Results) -> - Results; -distribute_work(Workers, [Job|Jobs], Results) -> - [Worker|RestWorkers] = Workers, - Worker ! {process, Job, self()}, - distribute_work(RestWorkers ++ [Worker], Jobs, Results). 
\ No newline at end of file diff --git a/src/performance_optimization.erl b/src/performance_optimization.erl index d83cc63..993b023 100644 --- a/src/performance_optimization.erl +++ b/src/performance_optimization.erl @@ -34,4 +34,29 @@ analyze_resource_usage() -> cpu => Usage, memory => Memory, process_count => erlang:system_info(process_count) - }. \ No newline at end of file + }. + +%% 处理同步调用 +handle_call(get_stats, _From, State) -> + {reply, {ok, State#state.performance_history}, State}; +handle_call(_Request, _From, State) -> + {reply, {error, unknown_call}, State}. + +%% 处理异步消息 +handle_cast(optimize, State) -> + ResourceUsage = analyze_resource_usage(), + OptimizationActions = calculate_optimization_actions(ResourceUsage, State#state.optimization_rules), + NewState = apply_optimization_actions(OptimizationActions, State), + {noreply, NewState#state{resource_usage = ResourceUsage}}; +handle_cast(_Msg, State) -> + {noreply, State}. + +%% 辅助函数 +calculate_optimization_actions(ResourceUsage, Rules) -> + % 简化版实现 + [balance_load, free_memory]. + +apply_optimization_actions(Actions, State) -> + % 简化版实现 + NewHistory = [{os:timestamp(), Actions} | State#state.performance_history], + State#state{performance_history = lists:sublist(NewHistory, 100)}. \ No newline at end of file diff --git a/src/system_supervisor.erl b/src/system_supervisor.erl index 5cbb681..214f6ef 100644 --- a/src/system_supervisor.erl +++ b/src/system_supervisor.erl @@ -4,6 +4,21 @@ -export([start_link/0, init/1]). -export([start_system/0, stop_system/0, system_status/0]). +%% 启动整个系统 +start_system() -> + application:ensure_all_started(cardSrv). + +%% 停止整个系统 +stop_system() -> + application:stop(cardSrv). + +%% 获取系统状态 +system_status() -> + [{supervisor, ?MODULE}, + {children, supervisor:which_children(?MODULE)}, + {memory, erlang:memory()}, + {process_count, erlang:system_info(process_count)}]. + start_link() -> supervisor:start_link({local, ?MODULE}, ?MODULE, []). diff --git a/src/test_suite.erl b/src/test_suite.erl deleted file mode 100644 index 2db805c..0000000 --- a/src/test_suite.erl +++ /dev/null @@ -1,37 +0,0 @@ --module(test_suite). --export([run_full_test/0, validate_ai_performance/1]). - -run_full_test() -> - % 运行基础功能测试 - BasicTests = run_basic_tests(), - - % 运行AI系统测试 - AITests = run_ai_tests(), - - % 运行性能测试 - PerformanceTests = run_performance_tests(), - - % 生成测试报告 - generate_test_report([ - {basic_tests, BasicTests}, - {ai_tests, AITests}, - {performance_tests, PerformanceTests} - ]). - -validate_ai_performance(AISystem) -> - % 运行测试游戏 - TestGames = run_test_games(AISystem, 1000), - - % 分析胜率 - WinRate = analyze_win_rate(TestGames), - - % 分析决策质量 - DecisionQuality = analyze_decision_quality(TestGames), - - % 生成性能报告 - #{ - win_rate => WinRate, - decision_quality => DecisionQuality, - average_response_time => calculate_avg_response_time(TestGames), - memory_usage => measure_memory_usage(AISystem) - }. \ No newline at end of file diff --git a/src/training_system.erl b/src/training_system.erl deleted file mode 100644 index 1d9a3ae..0000000 --- a/src/training_system.erl +++ /dev/null @@ -1,25 +0,0 @@ --module(training_system). --export([start_training/0, process_game_data/1, update_models/1]). - -%% 开始训练过程 -start_training() -> - TrainingData = load_training_data(), - Models = initialize_models(), - train_models(Models, TrainingData, [ - {epochs, 1000}, - {batch_size, 32}, - {learning_rate, 0.001} - ]). 
- -%% 处理游戏数据 -process_game_data(GameRecord) -> - Features = extract_features(GameRecord), - Labels = extract_labels(GameRecord), - update_training_dataset(Features, Labels). - -%% 更新模型 -update_models(NewData) -> - CurrentModels = get_current_models(), - UpdatedModels = retrain_models(CurrentModels, NewData), - validate_models(UpdatedModels), - deploy_models(UpdatedModels). \ No newline at end of file