@ -0,0 +1 @@ | |||
[]. |
@ -1,73 +0,0 @@ | |||
-module(advanced_learning).

-export([init/0, train/2, predict/2, update/3]).

%% NOTE(review): helper functions (initialize_neural_network/1,
%% prepare_batch/2, train_network/2, update_training_stats/2,
%% extract_features/1, process_prediction/1, update_experience_buffer/3,
%% should_train/1, prepare_training_data/1) are not defined in this
%% module — confirm where they live before compiling.

%% Learning state carried between training iterations.
-record(learning_state, {
    neural_network,      % deep neural network model (opaque)
    experience_buffer,   % experience replay buffer (a queue)
    model_version,       % model version counter
    training_stats       % training statistics map
}).

%% Neural network layer configuration: {layer_name, unit_count}.
-define(NETWORK_CONFIG, [
    {input_layer, 512},
    {hidden_layer_1, 256},
    {hidden_layer_2, 128},
    {hidden_layer_3, 64},
    {output_layer, 32}
]).

%% Training hyper-parameters.
-define(LEARNING_RATE, 0.001).
-define(BATCH_SIZE, 64).
-define(EXPERIENCE_BUFFER_SIZE, 10000).

%% @doc Build the initial learning state.
%% Fix: the original wrote `total_games = 0` etc. inside a map literal,
%% which is a syntax error; map construction uses `=>`.
init() ->
    Network = initialize_neural_network(?NETWORK_CONFIG),
    #learning_state{
        neural_network = Network,
        experience_buffer = queue:new(),
        model_version = 1,
        training_stats = #{
            total_games => 0,
            win_rate => 0.0,
            avg_reward => 0.0
        }
    }.

%% @doc Run one training step over a batch drawn from TrainingData and
%% return the state with the updated network and statistics.
train(State, TrainingData) ->
    Batch = prepare_batch(TrainingData, ?BATCH_SIZE),
    {UpdatedNetwork, Loss} = train_network(State#learning_state.neural_network, Batch),
    NewStats = update_training_stats(State#learning_state.training_stats, Loss),
    State#learning_state{
        neural_network = UpdatedNetwork,
        training_stats = NewStats
    }.

%% @doc Predict an action for Input using the current network.
predict(State, Input) ->
    Features = extract_features(Input),
    Prediction = neural_network:forward(State#learning_state.neural_network, Features),
    process_prediction(Prediction).

%% @doc Record an experience in the replay buffer and train when the
%% buffer indicates enough data has accumulated.
update(State, Experience, Reward) ->
    NewBuffer = update_experience_buffer(State#learning_state.experience_buffer,
                                         Experience,
                                         Reward),
    case should_train(NewBuffer) of
        true ->
            TrainingData = prepare_training_data(NewBuffer),
            train(State#learning_state{experience_buffer = NewBuffer}, TrainingData);
        false ->
            State#learning_state{experience_buffer = NewBuffer}
    end.
@ -1,176 +1,657 @@ | |||
-module(ai_core). | |||
-export([init_ai/1, make_decision/2, update_strategy/3]). | |||
-export([ | |||
init/1, | |||
make_decision/2, | |||
evaluate_hand/1, | |||
predict_plays/2, | |||
update_knowledge/2, | |||
calculate_win_rate/1 | |||
]). | |||
-include("card_types.hrl"). | |||
%% AI state.
%% NOTE(review): this record appears to merge two revisions' field sets
%% (the personality/strategy era and the role/hand era); the original
%% was missing the comma after game_history, a syntax error. Confirm
%% which fields are still live.
-record(ai_state, {
    personality,          % aggressive | conservative | balanced
    strategy_weights,     % strategy weight map (see init_weights/1)
    knowledge_base,       % accumulated play-pattern knowledge
    game_history = [],    % per-game history entries
    role,                 % dizhu | nongmin
    hand_cards = [],      % current hand
    known_cards = [],     % cards known to the AI
    played_cards = [],    % cards already played
    player_history = [],  % opponents' play history
    game_stage,           % early_game | mid_game | end_game
    strategy_cache = #{}  % cached strategy decisions
}).
%% AI初始化 | |||
init_ai(Personality) -> | |||
%% Per-decision context snapshot.
%% NOTE(review): several helpers in this file (calculate_tempo_score/2,
%% adjust_score_for_context/4, filter_candidates/2) read a game_stage
%% field from this record that the original never declared; it is added
%% here so those call sites can compile. Confirm where it gets set.
-record(play_context, {
    last_play,               % previous play to beat, or none
    cards_remaining,         % number of cards left in our hand
    num_greater_cards,       % how many of our cards beat the last play
    control_factor,          % control estimate (ratio * position bonus)
    game_stage = early_game, % early_game | mid_game | end_game
    must_play = false        % whether passing is not allowed
}).
%% @doc Initialise the AI state for the given role (dizhu | nongmin).
init(Role) ->
    #ai_state{
        role = Role,
        game_stage = early_game
    }.

%% @doc Decide the next action.
%% Returns {play, Play, NewState} or {pass, State}.
make_decision(AIState, GameState) ->
    Context = analyze_context(AIState, GameState),
    case should_play(Context, AIState) of
        true ->
            select_best_play(Context, AIState);
        false ->
            {pass, AIState}
    end.

%% @doc Score a hand of cards by decomposing it into components.
evaluate_hand(Cards) ->
    Components = analyze_components(Cards),
    calculate_hand_value(Components).

%% @doc Enumerate and score the plays that could follow LastPlay.
predict_plays(Cards, LastPlay) ->
    ValidPlays = generate_valid_plays(Cards, LastPlay),
    score_potential_plays(ValidPlays, LastPlay).

%% @doc Fold a game event into the AI state.
%% NOTE(review): update_knowledge/2 is also defined further down this
%% file (knowledge-base version); two definitions with the same
%% name/arity will not compile — reconcile before use.
update_knowledge(AIState, Event) ->
    update_state_with_event(AIState, Event).

%% @doc Estimate the probability of winning from the current state.
%% NOTE(review): evaluate_position/1 is not defined anywhere in this
%% file — confirm its source.
calculate_win_rate(AIState) ->
    HandStrength = evaluate_hand(AIState#ai_state.hand_cards),
    PositionValue = evaluate_position(AIState),
    ControlValue = evaluate_control(AIState),
    calculate_probability(HandStrength, PositionValue, ControlValue).
%% 策略更新 | |||
update_strategy(AIState, GameResult, GameHistory) -> | |||
NewWeights = adjust_weights(AIState#ai_state.strategy_weights, GameResult), | |||
NewKnowledge = update_knowledge(AIState#ai_state.knowledge_base, GameHistory), | |||
AIState#ai_state{ | |||
strategy_weights = NewWeights, | |||
knowledge_base = NewKnowledge, | |||
game_history = [GameHistory | AIState#ai_state.game_history] | |||
%% Internal helpers

%% Build a #play_context{} snapshot of the current decision point.
%% NOTE(review): get_last_play/1 and count_greater_cards/2 are not
%% defined in this chunk — confirm their source.
analyze_context(AIState, GameState) ->
    LastPlay = get_last_play(GameState),
    CardsRemaining = length(AIState#ai_state.hand_cards),
    GreaterCards = count_greater_cards(AIState#ai_state.hand_cards, LastPlay),
    ControlFactor = calculate_control_factor(AIState, GameState),
    #play_context{
        last_play = LastPlay,
        cards_remaining = CardsRemaining,
        num_greater_cards = GreaterCards,
        control_factor = ControlFactor,
        must_play = must_play(AIState, GameState)
    }.

%% Decide whether to play at all: forced plays and free leads always
%% play; otherwise beat the last play only when it looks worthwhile.
should_play(Context, AIState) ->
    case Context#play_context.must_play of
        true ->
            true;
        false ->
            case Context#play_context.last_play of
                none -> true;
                _ -> should_beat_last_play(Context, AIState)
            end
    end.
%% Personality-specific strategy weights.
%% NOTE(review): the balanced clause was garbled in the source; its
%% values (0.6/0.6/0.6/0.5) were recovered from the stray lines found
%% inside analyze_components/1 — confirm against history.
init_weights(aggressive) ->
    #{
        control_weight => 0.8,
        attack_weight => 0.7,
        defense_weight => 0.3,
        risk_weight => 0.6
    };
init_weights(conservative) ->
    #{
        control_weight => 0.5,
        attack_weight => 0.4,
        defense_weight => 0.8,
        risk_weight => 0.3
    };
init_weights(balanced) ->
    #{
        control_weight => 0.6,
        attack_weight => 0.6,
        defense_weight => 0.6,
        risk_weight => 0.5
    }.

%% Pick the best candidate play for the current context.
%% NOTE(review): another select_best_play/2 (RatedPlays version) exists
%% further down this file; duplicate name/arity will not compile.
select_best_play(Context, AIState) ->
    Candidates = generate_candidates(AIState#ai_state.hand_cards, Context),
    ScoredPlays = [
        {score_play(Play, Context, AIState), Play}
        || Play <- Candidates
    ],
    select_highest_scored_play(ScoredPlays, AIState).

%% Decompose a hand into its component groups.
%% NOTE(review): find_singles/1 etc. are not defined in this chunk —
%% confirm their source.
analyze_components(Cards) ->
    GroupedCards = group_cards_by_value(Cards),
    #{
        singles => find_singles(GroupedCards),
        pairs => find_pairs(GroupedCards),
        triples => find_triples(GroupedCards),
        sequences => find_sequences(GroupedCards),
        bombs => find_bombs(GroupedCards)
    }.
analyze_situation(GameState) -> | |||
#{ | |||
hand_strength => evaluate_hand_strength(GameState), | |||
control_status => evaluate_control(GameState), | |||
opponent_cards => estimate_opponent_cards(GameState), | |||
game_stage => determine_game_stage(GameState) | |||
%% Total value of a decomposed hand (sum of per-component values).
calculate_hand_value(Components) ->
    SinglesValue = calculate_singles_value(maps:get(singles, Components, [])),
    PairsValue = calculate_pairs_value(maps:get(pairs, Components, [])),
    TriplesValue = calculate_triples_value(maps:get(triples, Components, [])),
    SequencesValue = calculate_sequences_value(maps:get(sequences, Components, [])),
    BombsValue = calculate_bombs_value(maps:get(bombs, Components, [])),
    SinglesValue + PairsValue + TriplesValue + SequencesValue + BombsValue.

%% All legal plays given the last play (none = we lead).
generate_valid_plays(Cards, none) ->
    generate_leading_plays(Cards);
generate_valid_plays(Cards, Play) ->
    generate_following_plays(Cards, Play).

%% Pair each candidate play with its potential score.
%% NOTE(review): score_play_potential/2 is not defined in this chunk.
score_potential_plays(Plays, LastPlay) ->
    [{Play, score_play_potential(Play, LastPlay)} || Play <- Plays].

%% Fold a single game event into the AI state.
update_state_with_event(AIState, {play_cards, Player, Cards}) ->
    NewPlayed = AIState#ai_state.played_cards ++ Cards,
    NewHistory = [{Player, Cards} | AIState#ai_state.player_history],
    AIState#ai_state{
        played_cards = NewPlayed,
        player_history = NewHistory
    };
update_state_with_event(AIState, {game_stage, NewStage}) ->
    AIState#ai_state{game_stage = NewStage};
update_state_with_event(AIState, _) ->
    AIState.

%% Weighted blend of strength/position/control, clamped to [0.0, 1.0].
calculate_probability(HandStrength, PositionValue, ControlValue) ->
    BaseProb = HandStrength * 0.5 + PositionValue * 0.3 + ControlValue * 0.2,
    normalize_probability(BaseProb).
%% Ratio of control cards in hand, scaled by table position.
%% Fix: the original bound TotalCards from count_total_remaining_cards/1
%% and never used it; the dead binding is removed.
calculate_control_factor(AIState, GameState) ->
    ControlCards = count_control_cards(AIState#ai_state.hand_cards),
    RemainingCards = length(AIState#ai_state.hand_cards),
    ControlRatio = ControlCards / max(1, RemainingCards),
    PositionBonus = calculate_position_bonus(AIState, GameState),
    ControlRatio * PositionBonus.

%% We must play when it is our turn and there is nothing to beat.
must_play(AIState, GameState) ->
    is_current_player(AIState, GameState) andalso
        not has_active_play(GameState).

%% Should we spend cards to beat the last play?
should_beat_last_play(Context, AIState) ->
    case Context#play_context.last_play of
        none ->
            true;
        LastPlay ->
            HandStrength = evaluate_hand(AIState#ai_state.hand_cards),
            ControlLevel = Context#play_context.control_factor,
            CardsLeft = Context#play_context.cards_remaining,
            should_beat(HandStrength, ControlLevel, CardsLeft, LastPlay)
    end.

%% Candidate plays for this context, filtered by game stage.
generate_candidates(Cards, Context) ->
    BasePlays = case Context#play_context.last_play of
        none -> generate_leading_plays(Cards);
        LastPlay -> generate_following_plays(Cards, LastPlay)
    end,
    filter_candidates(BasePlays, Context).

%% Composite score of a play: base/tempo/control/efficiency blend,
%% then adjusted for role and game stage.
score_play(Play, Context, AIState) ->
    BaseScore = calculate_base_score(Play),
    TempoScore = calculate_tempo_score(Play, Context),
    ControlScore = calculate_control_score(Play, Context, AIState),
    EfficiencyScore = calculate_efficiency_score(Play, Context),
    FinalScore = BaseScore * 0.4 +
                 TempoScore * 0.2 +
                 ControlScore * 0.3 +
                 EfficiencyScore * 0.1,
    adjust_score_for_context(FinalScore, Play, Context, AIState).

%% Take the highest-scoring play, or pass when nothing scores above 0.
%% (The >= comparator sorts descending, so the head is the best play.)
select_highest_scored_play(ScoredPlays, AIState) ->
    case lists:sort(fun({Score1, _}, {Score2, _}) ->
                        Score1 >= Score2
                    end, ScoredPlays) of
        [{Score, Play} | _] when Score > 0 ->
            {play, Play, update_after_play(Play, AIState)};
        _ ->
            {pass, AIState}
    end.
%% Value of loose singles; 2s and jokers are weighted up by 1.5x.
calculate_singles_value(Singles) ->
    lists:sum([
        case Value of
            V when V >= ?CARD_2 -> Value * 1.5;
            _ -> Value
        end || {Value, _} <- Singles
    ]).

%% Value of pairs.
calculate_pairs_value(Pairs) ->
    lists:sum([Value * 2.2 || {Value, _} <- Pairs]).

%% Value of triples.
calculate_triples_value(Triples) ->
    lists:sum([Value * 3.5 || {Value, _} <- Triples]).

%% Value of straights; scales with run length.
calculate_sequences_value(Sequences) ->
    lists:sum([
        Value * length(Cards) * 1.8
        || {Value, Cards} <- Sequences
    ]).

%% Value of bombs.
calculate_bombs_value(Bombs) ->
    lists:sum([Value * 10.0 || {Value, _} <- Bombs]).

%% Every play we could lead with, across all component types.
generate_leading_plays(Cards) ->
    Components = analyze_components(Cards),
    Singles = generate_single_plays(Components),
    Pairs = generate_pair_plays(Components),
    Triples = generate_triple_plays(Components),
    Sequences = generate_sequence_plays(Components),
    Bombs = generate_bomb_plays(Components),
    Singles ++ Pairs ++ Triples ++ Sequences ++ Bombs.
%% Plays that can follow the given last play: same type beaten on
%% value, plus bombs and the rocket as trumps.
%% (The unused LastPlay alias binding from the original is dropped.)
generate_following_plays(Cards, {Type, Value, _}) ->
    ValidPlays = find_greater_plays(Cards, Type, Value),
    BombPlays = find_bomb_plays(Cards),
    RocketPlay = find_rocket_play(Cards),
    ValidPlays ++ BombPlays ++ RocketPlay.

%% Base score: total card value scaled by a per-type multiplier.
calculate_base_score({Type, Value, Cards}) ->
    BaseValue = Value * length(Cards),
    TypeMultiplier = case Type of
        ?CARD_TYPE_ROCKET -> 100.0;
        ?CARD_TYPE_BOMB -> 80.0;
        ?CARD_TYPE_STRAIGHT -> 40.0;
        ?CARD_TYPE_STRAIGHT_PAIR -> 35.0;
        ?CARD_TYPE_PLANE -> 30.0;
        ?CARD_TYPE_THREE_TWO -> 25.0;
        ?CARD_TYPE_THREE_ONE -> 20.0;
        ?CARD_TYPE_THREE -> 15.0;
        ?CARD_TYPE_PAIR -> 10.0;
        ?CARD_TYPE_SINGLE -> 5.0
    end,
    BaseValue * TypeMultiplier / 100.0.

%% Stage-dependent tempo score.
%% NOTE(review): reads game_stage from #play_context{}, a field the
%% record as originally written did not declare — confirm it exists.
calculate_tempo_score(Play, Context) ->
    case Context#play_context.game_stage of
        early_game -> calculate_early_tempo(Play, Context);
        mid_game -> calculate_mid_tempo(Play, Context);
        end_game -> calculate_end_tempo(Play, Context)
    end.
%% Control contribution: bombs/rockets dominate, 2s-and-up are strong.
%% (AIState was unused; underscore-prefixed to silence the warning.)
calculate_control_score(Play, Context, _AIState) ->
    {Type, Value, _} = Play,
    BaseControl = case Type of
        ?CARD_TYPE_BOMB -> 1.0;
        ?CARD_TYPE_ROCKET -> 1.0;
        _ when Value >= ?CARD_2 -> 0.8;
        _ -> 0.5
    end,
    BaseControl * Context#play_context.control_factor.

%% Efficiency: prefer plays that shed more cards, with a bonus the
%% closer they bring us to emptying the hand.
calculate_efficiency_score(Play, Context) ->
    {_, _, Cards} = Play,
    CardsUsed = length(Cards),
    RemainingCards = Context#play_context.cards_remaining - CardsUsed,
    Efficiency = CardsUsed / max(1, Context#play_context.cards_remaining),
    Efficiency * (1 + (20 - RemainingCards) / 20).

%% Role and stage multipliers applied to the raw score.
%% NOTE(review): reads game_stage from #play_context{}, a field the
%% record as originally written did not declare — confirm it exists.
%% (Play was unused; underscore-prefixed.)
adjust_score_for_context(Score, _Play, Context, AIState) ->
    RoleMultiplier = case AIState#ai_state.role of
        dizhu -> 1.2;
        nongmin -> 1.0
    end,
    StageMultiplier = case Context#play_context.game_stage of
        early_game -> 0.9;
        mid_game -> 1.0;
        end_game -> 1.1
    end,
    Score * RoleMultiplier * StageMultiplier.

%% Remove the played cards from the hand and record them as played.
update_after_play(Play, AIState) ->
    {_, _, Cards} = Play,
    NewHand = AIState#ai_state.hand_cards -- Cards,
    NewPlayed = AIState#ai_state.played_cards ++ Cards,
    AIState#ai_state{
        hand_cards = NewHand,
        played_cards = NewPlayed
    }.
%% --- Legacy strategy helpers (pre-context engine) ---
%% NOTE(review): get_my_cards/1 and get_last_play/1 are not defined in
%% this chunk — confirm their source.

%% Enumerate plays available in the raw game state.
generate_possible_plays(GameState) ->
    MyCards = get_my_cards(GameState),
    LastPlay = get_last_play(GameState),
    generate_valid_plays(MyCards, LastPlay).

%% Score each play under the current strategy weights.
evaluate_plays(Plays, AIState, Situation) ->
    lists:map(
        fun(Play) ->
            Score = calculate_play_score(Play, AIState, Situation),
            {Play, Score}
        end,
        Plays
    ).

%% Weighted sum of control/attack/defense/risk evaluations.
calculate_play_score(Play, AIState, Situation) ->
    Weights = AIState#ai_state.strategy_weights,
    ControlScore = evaluate_control_value(Play, Situation) *
        maps:get(control_weight, Weights),
    AttackScore = evaluate_attack_value(Play, Situation) *
        maps:get(attack_weight, Weights),
    DefenseScore = evaluate_defense_value(Play, Situation) *
        maps:get(defense_weight, Weights),
    RiskScore = evaluate_risk_value(Play, Situation) *
        maps:get(risk_weight, Weights),
    ControlScore + AttackScore + DefenseScore + RiskScore.

%% Personality-based selection dispatch.
%% NOTE(review): select_best_play/2 is also defined earlier in this
%% file (context version); duplicate name/arity will not compile —
%% reconcile before use.
select_best_play(RatedPlays, AIState) ->
    case AIState#ai_state.personality of
        aggressive ->
            select_aggressive(RatedPlays);
        conservative ->
            select_conservative(RatedPlays);
        balanced ->
            select_balanced(RatedPlays)
    end.
%% Strategy selection helpers

%% Aggressive: pick the play with the highest score.
%% Fix: the original used lists:max/1 on {Play, Score} tuples, which
%% compares by Play term order first, not by Score; the maximum is now
%% taken over the scores explicitly.
select_aggressive([First | Rest]) ->
    {Play, _Score} =
        lists:foldl(
            fun({_, S} = Cand, {_, BestS}) when S > BestS -> Cand;
               (_, Best) -> Best
            end,
            First, Rest),
    Play.
%% Conservative: prefer low-risk plays; fall back to the balanced
%% choice when none qualifies.
%% NOTE(review): filter_safe_plays/1, select_from_safe_plays/1 and
%% select_balanced_play/1 are not defined in this chunk.
select_conservative(RatedPlays) ->
    case filter_safe_plays(RatedPlays) of
        [] -> select_balanced(RatedPlays);
        SafePlays -> select_from_safe_plays(SafePlays)
    end.

%% Balanced: trade score off against risk.
select_balanced(RatedPlays) ->
    {Play, _Score} = select_balanced_play(RatedPlays),
    Play.
%% Legacy evaluation helpers

%% Strength of our current hand from the raw game state.
%% Fix: calculate_hand_value/1 expects the component map produced by
%% analyze_components/1, not a raw card list; decompose first.
evaluate_hand_strength(GameState) ->
    Cards = get_my_cards(GameState),
    calculate_hand_value(analyze_components(Cards)).

%% Whether we currently control the table.
evaluate_control(GameState) ->
    LastPlay = get_last_play(GameState),
    MyCards = get_my_cards(GameState),
    can_control_game(MyCards, LastPlay).

%% Estimate opponents' remaining cards from what has been played.
estimate_opponent_cards(GameState) ->
    PlayedCards = get_played_cards(GameState),
    MyCards = get_my_cards(GameState),
    estimate_remaining_cards(PlayedCards, MyCards).

%% Merge patterns extracted from a finished game into the knowledge base.
%% NOTE(review): update_knowledge/2 is also defined near the top of this
%% file (event version); duplicate name/arity will not compile.
update_knowledge(KnowledgeBase, GameHistory) ->
    NewPatterns = extract_patterns(GameHistory),
    merge_knowledge(KnowledgeBase, NewPatterns).

%% Count occurrences of each play pattern in a game history.
extract_patterns(GameHistory) ->
    lists:foldl(
        fun(Play, Patterns) ->
            Pattern = analyze_play_pattern(Play),
            update_pattern_stats(Pattern, Patterns)
        end,
        #{},
        GameHistory
    ).

%% Merge two knowledge maps, combining values for shared keys.
%% (maps:merge_with/3 requires OTP 24+.)
merge_knowledge(Old, New) ->
    maps:merge_with(
        fun(_Key, OldValue, NewValue) ->
            update_knowledge_value(OldValue, NewValue)
        end,
        Old,
        New
    ).
%% Misc helpers

%% Clamp a probability into [0.0, 1.0].
normalize_probability(P) ->
    max(0.0, min(1.0, P)).
%% Number of control cards (value >= 2) held in hand.
count_control_cards(Cards) ->
    length([C || {V, _} = C <- Cards, V >= ?CARD_2]).
%% Positional bonus multiplier.
calculate_position_bonus(AIState, GameState) ->
    case get_position(AIState, GameState) of
        first -> 1.2;
        middle -> 1.0;
        last -> 0.8
    end.

%% Placeholder: table position.
%% TODO derive from the real game state; params underscore-prefixed
%% because the stub ignores them.
get_position(_AIState, _GameState) ->
    first.

%% Placeholder: whether it is our turn. TODO derive from game state.
is_current_player(_AIState, _GameState) ->
    true.

%% Placeholder: whether a play is pending on the table. TODO derive.
has_active_play(_GameState) ->
    false.

%% Beat the last play only when its value falls below a threshold built
%% from our hand strength, control level and cards remaining.
should_beat(HandStrength, ControlLevel, CardsLeft, LastPlay) ->
    BaseThreshold = 0.6,
    StrengthFactor = HandStrength / 100,
    ControlFactor = ControlLevel / 100,
    CardsFactor = (20 - CardsLeft) / 20,
    PlayThreshold = BaseThreshold * (StrengthFactor + ControlFactor + CardsFactor) / 3,
    evaluate_play_value(LastPlay) < PlayThreshold.

%% Value of a play: type base scaled by card value and card count.
evaluate_play_value({Type, Value, Cards}) ->
    BaseValue = case Type of
        ?CARD_TYPE_ROCKET -> 1.0;
        ?CARD_TYPE_BOMB -> 0.9;
        ?CARD_TYPE_PLANE -> 0.7;
        ?CARD_TYPE_STRAIGHT -> 0.6;
        ?CARD_TYPE_STRAIGHT_PAIR -> 0.5;
        ?CARD_TYPE_THREE_TWO -> 0.4;
        ?CARD_TYPE_THREE_ONE -> 0.3;
        ?CARD_TYPE_THREE -> 0.25;
        ?CARD_TYPE_PAIR -> 0.2;
        ?CARD_TYPE_SINGLE -> 0.1
    end,
    ValueBonus = Value / ?CARD_JOKER_BIG,
    CardCountFactor = length(Cards) / 10,
    BaseValue * (1 + ValueBonus) * (1 + CardCountFactor).
%% Play generation from a component map

generate_single_plays(Components) ->
    [{?CARD_TYPE_SINGLE, Value, [Card]} ||
        {Value, Card} <- maps:get(singles, Components, [])].

generate_pair_plays(Components) ->
    [{?CARD_TYPE_PAIR, Value, Cards} ||
        {Value, Cards} <- maps:get(pairs, Components, [])].

%% Plain triples plus their three-with-one / three-with-two combinations.
%% Fix: the original built BasicTriples and then discarded it, returning
%% only the combinations; the plain triples are now included.
generate_triple_plays(Components) ->
    Triples = maps:get(triples, Components, []),
    BasicTriples = [{?CARD_TYPE_THREE, Value, Cards} ||
                       {Value, Cards} <- Triples],
    BasicTriples ++ generate_triple_combinations(Triples, Components).

generate_sequence_plays(Components) ->
    Sequences = maps:get(sequences, Components, []),
    [{?CARD_TYPE_STRAIGHT, Value, Cards} ||
        {Value, Cards} <- Sequences].

generate_bomb_plays(Components) ->
    [{?CARD_TYPE_BOMB, Value, Cards} ||
        {Value, Cards} <- maps:get(bombs, Components, [])].

%% Three-with-one and three-with-two combinations.
generate_triple_combinations(Triples, Components) ->
    Singles = maps:get(singles, Components, []),
    Pairs = maps:get(pairs, Components, []),
    ThreeOne = generate_three_one(Triples, Singles),
    ThreeTwo = generate_three_two(Triples, Pairs),
    ThreeOne ++ ThreeTwo.

%% Every triple paired with every non-matching single kicker.
generate_three_one(Triples, Singles) ->
    [{?CARD_TYPE_THREE_ONE, TripleValue, TripleCards ++ [SingleCard]} ||
        {TripleValue, TripleCards} <- Triples,
        {SingleValue, SingleCard} <- Singles,
        SingleValue =/= TripleValue].

%% Every triple paired with every non-matching pair kicker.
generate_three_two(Triples, Pairs) ->
    [{?CARD_TYPE_THREE_TWO, TripleValue, TripleCards ++ PairCards} ||
        {TripleValue, TripleCards} <- Triples,
        {PairValue, PairCards} <- Pairs,
        PairValue =/= TripleValue].
%% Lookups for plays of a specific type that beat MinValue

find_greater_plays(Cards, Type, MinValue) ->
    Components = analyze_components(Cards),
    case Type of
        ?CARD_TYPE_SINGLE ->
            find_greater_singles(Components, MinValue);
        ?CARD_TYPE_PAIR ->
            find_greater_pairs(Components, MinValue);
        ?CARD_TYPE_THREE ->
            find_greater_triples(Components, MinValue);
        ?CARD_TYPE_THREE_ONE ->
            find_greater_three_one(Components, MinValue);
        ?CARD_TYPE_THREE_TWO ->
            find_greater_three_two(Components, MinValue);
        ?CARD_TYPE_STRAIGHT ->
            find_greater_straight(Components, MinValue);
        ?CARD_TYPE_STRAIGHT_PAIR ->
            find_greater_straight_pair(Components, MinValue);
        ?CARD_TYPE_PLANE ->
            find_greater_plane(Components, MinValue);
        ?CARD_TYPE_BOMB ->
            find_greater_bomb(Components, MinValue);
        _ ->
            []
    end.

find_greater_singles(Components, MinValue) ->
    [{?CARD_TYPE_SINGLE, Value, [Card]} ||
        {Value, Card} <- maps:get(singles, Components, []),
        Value > MinValue].

find_greater_pairs(Components, MinValue) ->
    [{?CARD_TYPE_PAIR, Value, Cards} ||
        {Value, Cards} <- maps:get(pairs, Components, []),
        Value > MinValue].

find_greater_triples(Components, MinValue) ->
    [{?CARD_TYPE_THREE, Value, Cards} ||
        {Value, Cards} <- maps:get(triples, Components, []),
        Value > MinValue].

%% Only the triple must beat MinValue; kickers are unrestricted.
find_greater_three_one(Components, MinValue) ->
    Triples = [{V, C} || {V, C} <- maps:get(triples, Components, []),
                         V > MinValue],
    Singles = maps:get(singles, Components, []),
    generate_three_one(Triples, Singles).

find_greater_three_two(Components, MinValue) ->
    Triples = [{V, C} || {V, C} <- maps:get(triples, Components, []),
                         V > MinValue],
    Pairs = maps:get(pairs, Components, []),
    generate_three_two(Triples, Pairs).

%% NOTE(review): straights are compared by head value only; run length
%% is not checked here — confirm that is enforced elsewhere.
find_greater_straight(Components, MinValue) ->
    Sequences = maps:get(sequences, Components, []),
    [{?CARD_TYPE_STRAIGHT, Value, Cards} ||
        {Value, Cards} <- Sequences,
        Value > MinValue].

find_greater_straight_pair(Components, MinValue) ->
    Sequences = find_pair_sequences(Components),
    [{?CARD_TYPE_STRAIGHT_PAIR, Value, Cards} ||
        {Value, Cards} <- Sequences,
        Value > MinValue].

find_greater_plane(Components, MinValue) ->
    Planes = find_planes(Components),
    [{?CARD_TYPE_PLANE, Value, Cards} ||
        {Value, Cards} <- Planes,
        Value > MinValue].

find_greater_bomb(Components, MinValue) ->
    [{?CARD_TYPE_BOMB, Value, Cards} ||
        {Value, Cards} <- maps:get(bombs, Components, []),
        Value > MinValue].
%% All bombs held in the hand.
%% Fix: the original comprehension generator reused the already-bound
%% variable `Cards` (the function argument); in Erlang a bound variable
%% in a generator pattern is a *match*, so the list was (almost) always
%% empty. The generator variable is renamed so it binds freshly.
find_bomb_plays(Cards) ->
    Components = analyze_components(Cards),
    [{?CARD_TYPE_BOMB, Value, BombCards} ||
        {Value, BombCards} <- maps:get(bombs, Components, [])].

%% The rocket (both jokers) as a playable move, if held.
find_rocket_play(Cards) ->
    Components = analyze_components(Cards),
    case find_rocket(Components) of
        {ok, Rocket} -> [Rocket];
        _ -> []
    end.

%% Both jokers present -> the rocket play.
find_rocket(Components) ->
    case {find_card(?CARD_JOKER_SMALL, Components),
          find_card(?CARD_JOKER_BIG, Components)} of
        {{ok, Small}, {ok, Big}} ->
            {ok, {?CARD_TYPE_ROCKET, ?CARD_JOKER_BIG, [Small, Big]}};
        _ ->
            false
    end.

%% Look up a single card of the given value in the singles component.
find_card(Value, Components) ->
    Singles = maps:get(singles, Components, []),
    case lists:keyfind(Value, 1, Singles) of
        {Value, Card} -> {ok, Card};
        _ -> false
    end.
%% Consecutive-pair (straight-pair) discovery.
find_pair_sequences(Components) ->
    Pairs = maps:get(pairs, Components, []),
    find_consecutive_pairs(lists:sort(Pairs), []).

%% Consecutive-triple (plane) discovery.
find_planes(Components) ->
    Triples = maps:get(triples, Components, []),
    find_consecutive_triples(lists:sort(Triples), []).

%% Collect runs of consecutive pairs.
%% NOTE(review): Sequence is a flattened card list, so the >= 3 guard
%% admits runs of 1.5 pairs while the triple version requires >= 6
%% cards (2 triples) — confirm the intended minimum run lengths.
find_consecutive_pairs([], Acc) ->
    lists:reverse(Acc);
find_consecutive_pairs([{V1, Cards1} | Rest], Acc) ->
    case find_consecutive_pair_sequence(V1, Cards1, Rest) of
        {Sequence, NewRest} when length(Sequence) >= 3 ->
            find_consecutive_pairs(NewRest, [{V1, Sequence} | Acc]);
        _ ->
            find_consecutive_pairs(Rest, Acc)
    end.

find_consecutive_pair_sequence(Value, Cards, Rest) ->
    find_consecutive_pair_sequence(Value, Cards, Rest, [Cards]).

%% Extend the run while the next value is exactly one higher.
find_consecutive_pair_sequence(_Value, _, [], Acc) ->
    {lists:flatten(lists:reverse(Acc)), []};
find_consecutive_pair_sequence(Value, _, [{NextValue, NextCards} | Rest], Acc)
        when NextValue =:= Value + 1 ->
    find_consecutive_pair_sequence(NextValue, NextCards, Rest, [NextCards | Acc]);
find_consecutive_pair_sequence(_, _, Rest, Acc) ->
    {lists:flatten(lists:reverse(Acc)), Rest}.

%% Collect runs of consecutive triples (>= 6 cards, i.e. 2 triples).
find_consecutive_triples([], Acc) ->
    lists:reverse(Acc);
find_consecutive_triples([{V1, Cards1} | Rest], Acc) ->
    case find_consecutive_triple_sequence(V1, Cards1, Rest) of
        {Sequence, NewRest} when length(Sequence) >= 6 ->
            find_consecutive_triples(NewRest, [{V1, Sequence} | Acc]);
        _ ->
            find_consecutive_triples(Rest, Acc)
    end.

find_consecutive_triple_sequence(Value, Cards, Rest) ->
    find_consecutive_triple_sequence(Value, Cards, Rest, [Cards]).

find_consecutive_triple_sequence(_Value, _, [], Acc) ->
    {lists:flatten(lists:reverse(Acc)), []};
find_consecutive_triple_sequence(Value, _, [{NextValue, NextCards} | Rest], Acc)
        when NextValue =:= Value + 1 ->
    find_consecutive_triple_sequence(NextValue, NextCards, Rest, [NextCards | Acc]);
find_consecutive_triple_sequence(_, _, Rest, Acc) ->
    {lists:flatten(lists:reverse(Acc)), Rest}.
%% Tempo scores per game stage

%% Early game: shed small singles/pairs and long runs.
calculate_early_tempo({Type, Value, _}, _Context) ->
    case Type of
        ?CARD_TYPE_SINGLE when Value < ?CARD_2 -> 0.8;
        ?CARD_TYPE_PAIR when Value < ?CARD_2 -> 0.7;
        ?CARD_TYPE_STRAIGHT -> 0.9;
        ?CARD_TYPE_STRAIGHT_PAIR -> 0.85;
        _ -> 0.5
    end.

%% Mid game: favour triples-with-kickers and planes.
%% (Value was unused; underscore-prefixed.)
calculate_mid_tempo({Type, _Value, _}, _Context) ->
    case Type of
        ?CARD_TYPE_THREE_ONE -> 0.8;
        ?CARD_TYPE_THREE_TWO -> 0.85;
        ?CARD_TYPE_PLANE -> 0.9;
        ?CARD_TYPE_BOMB -> 0.7;
        _ -> 0.6
    end.

%% End game: trumps and any play made with few cards remaining.
calculate_end_tempo({Type, _Value, _}, Context) ->
    CardsLeft = Context#play_context.cards_remaining,
    case Type of
        ?CARD_TYPE_BOMB -> 0.9;
        ?CARD_TYPE_ROCKET -> 1.0;
        _ when CardsLeft =< 4 -> 0.95;
        _ -> 0.7
    end.
%% Stage-based candidate filtering.
%% NOTE(review): reads game_stage from #play_context{}, a field the
%% record as originally written did not declare — confirm it exists.
filter_candidates(Plays, Context) ->
    case Context#play_context.game_stage of
        early_game ->
            filter_early_game_plays(Plays, Context);
        mid_game ->
            filter_mid_game_plays(Plays, Context);
        end_game ->
            filter_end_game_plays(Plays, Context)
    end.

%% Early game: hold bombs back.
%% NOTE(review): the original comment said "keep bombs in reserve", but
%% the guard as written only drops bombs whose value is below 2 —
%% confirm the intended condition.
filter_early_game_plays(Plays, _Context) ->
    [Play || {Type, Value, _} = Play <- Plays,
             Type =/= ?CARD_TYPE_BOMB orelse Value >= ?CARD_2].

%% Mid game: allow bombs only while we lack control.
filter_mid_game_plays(Plays, Context) ->
    case Context#play_context.control_factor < 0.5 of
        true ->
            Plays;
        false ->
            [Play || {Type, _, _} = Play <- Plays,
                     Type =/= ?CARD_TYPE_BOMB]
    end.

%% End game: every play type is allowed.
filter_end_game_plays(Plays, _Context) ->
    Plays.
%% Group cards into a map keyed by card value. Cards are prepended, so
%% each bucket holds its cards in reverse encounter order.
group_cards_by_value(Cards) ->
    lists:foldl(fun({Value, _} = Card, Acc) ->
                    maps:update_with(Value,
                                     fun(List) -> [Card | List] end,
                                     [Card],
                                     Acc)
                end, maps:new(), Cards).
@ -1,58 +0,0 @@ | |||
-module(ai_test).

-export([run_test/0]).

%% @doc End-to-end smoke test: boots the support services, trains a
%% network, runs one prediction, samples performance metrics, renders
%% a chart, and returns all artefacts in a map.
%% NOTE(review): the four service pids are started but never used or
%% stopped (only the monitor is cleaned up) — confirm their lifecycle
%% is managed elsewhere. Unused bindings are underscore-prefixed to
%% silence compiler warnings.
run_test() ->
    %% Start the required services.
    {ok, _DL} = deep_learning:start_link(),
    {ok, _PC} = parallel_compute:start_link(),
    {ok, _PM} = performance_monitor:start_link(),
    {ok, _VS} = visualization:start_link(),
    %% Build training data and train the network.
    TestData = create_test_data(),
    {ok, _Network} = deep_learning:train_network(test_network, TestData),
    %% Run one prediction.
    TestInput = prepare_test_input(),
    {ok, Prediction} = deep_learning:predict(test_network, TestInput),
    %% Monitor performance for a short window.
    {ok, MonitorId} = performance_monitor:start_monitoring(test_network),
    timer:sleep(5000),
    {ok, Metrics} = performance_monitor:get_metrics(MonitorId),
    %% Visualise and export the results.
    {ok, ChartId} = visualization:create_chart(line_chart, Metrics),
    {ok, Report} = performance_monitor:generate_report(MonitorId),
    {ok, Chart} = visualization:export_chart(ChartId, png),
    %% Release resources.
    ok = performance_monitor:stop_monitoring(MonitorId),
    #{
        prediction => Prediction,
        metrics => Metrics,
        report => Report,
        chart => Chart
    }.

%% Training samples: {Input, ExpectedOutput}.
create_test_data() ->
    [
        {[1, 2, 3], [4]},
        {[2, 3, 4], [5]},
        {[3, 4, 5], [6]},
        {[4, 5, 6], [7]}
    ].

%% Input for the prediction smoke check.
prepare_test_input() ->
    [5, 6, 7].
@ -0,0 +1,297 @@ | |||
-module(card_checker). | |||
-include("card_types.hrl"). | |||
%% API exports | |||
-export([ | |||
check_card_type/1, % 检查牌型 | |||
compare_cards/2, % 比较两手牌 | |||
is_valid_play/2, % 验证出牌是否合法 | |||
format_cards/1, % 格式化牌的显示 | |||
validate_cards/1 % 验证牌的合法性 | |||
]). | |||
-type card() :: {integer(), atom()}. | |||
-type cards() :: [card()]. | |||
-type card_type() :: integer(). | |||
-type card_value() :: integer(). | |||
%% API implementations

%% @doc Classify a hand of cards: sort it and delegate to the
%% size-based identifier. The empty hand is rejected up front. The
%% original wrapped identify_type/1 in an identity case; the result is
%% simply returned directly here.
-spec check_card_type(cards()) -> {ok, card_type(), card_value()} | {error, invalid_type}.
check_card_type(Cards) when length(Cards) > 0 ->
    identify_type(sort_cards(Cards));
check_card_type([]) ->
    {error, invalid_type}.
%% @doc Compare two hands. Both must classify to recognised types,
%% otherwise the comparison is reported as invalid.
-spec compare_cards(cards(), cards()) -> greater | lesser | invalid.
compare_cards(Cards1, Cards2) ->
    compare_classified(check_card_type(Cards1), check_card_type(Cards2)).

%% Compare two classification results.
compare_classified({ok, Type1, Value1}, {ok, Type2, Value2}) ->
    compare_types_and_values(Type1, Value1, Type2, Value2);
compare_classified(_, _) ->
    invalid.

%% @doc Is NewCards a legal play on top of LastCards? `undefined' means
%% we are leading, in which case any hand is accepted unchecked.
-spec is_valid_play(cards(), cards() | undefined) -> boolean().
is_valid_play(_NewCards, undefined) ->
    true;
is_valid_play(NewCards, LastCards) ->
    beats_classified(check_card_type(NewCards), check_card_type(LastCards)).

%% A follow is valid only when both hands classify and ours wins.
beats_classified({ok, Type1, Value1}, {ok, Type2, Value2}) ->
    can_beat(Type1, Value1, Type2, Value2);
beats_classified(_, _) ->
    false.
%% @doc Render a hand as one printable string.
%% lists:map/2 alone produced a list of per-card strings (a deep list),
%% violating the -spec string() contract; lists:flatmap/2 concatenates
%% the per-card strings into a flat string.
-spec format_cards(cards()) -> string().
format_cards(Cards) ->
    lists:flatmap(fun format_card/1, Cards).
%% @doc Validate a hand: a non-empty list with no duplicate cards in
%% which every card is individually well-formed. The checks run in
%% order and short-circuit on the first failure, exactly like the
%% original andalso chain.
-spec validate_cards(cards()) -> boolean().
validate_cards(Cards) ->
    Checks = [fun is_valid_card_list/1,
              fun no_duplicate_cards/1,
              fun all_cards_valid/1],
    lists:all(fun(Check) -> Check(Cards) end, Checks).
%% Internal functions

%% @private Identify the type of a sorted hand, dispatching on size.
%% Five cards can be either three-with-pair or a five-card straight;
%% the original only tried check_three_two/1, so legal five-card
%% straights (e.g. 3-4-5-6-7) were rejected as invalid. Fall back to
%% the sequence checks when the 3+2 shape does not match.
identify_type(Cards) ->
    case length(Cards) of
        1 -> {ok, ?CARD_TYPE_SINGLE, get_card_value(hd(Cards))};
        2 -> check_pair_or_rocket(Cards);
        3 -> check_three(Cards);
        4 -> check_four_or_three_one(Cards);
        5 -> check_five_cards(Cards);
        _ -> check_sequence_types(Cards)
    end.

%% @private A five-card hand: try 3+2 first, then straight shapes.
check_five_cards(Cards) ->
    case check_three_two(Cards) of
        {ok, _, _} = Result -> Result;
        {error, invalid_type} -> check_sequence_types(Cards)
    end.
%% @private Classify a sorted two-card hand as a pair or the rocket.
%% sort_cards/1 orders cards ascending, so the rocket arrives as
%% [small joker, big joker]. The original compared V1 against the BIG
%% joker and V2 against the SMALL one — impossible for a sorted hand —
%% so rockets were always reported as invalid. It also bound the suits
%% and the whole list without using them (unused-variable warnings).
check_pair_or_rocket([{V, _}, {V, _}]) ->
    {ok, ?CARD_TYPE_PAIR, V};
check_pair_or_rocket([{?CARD_JOKER_SMALL, _}, {?CARD_JOKER_BIG, _}]) ->
    {ok, ?CARD_TYPE_ROCKET, ?CARD_JOKER_BIG};
check_pair_or_rocket(_) ->
    {error, invalid_type}.
%% @private A triple: three cards of the same value.
check_three([{V, _}, {V, _}, {V, _}]) ->
    {ok, ?CARD_TYPE_THREE, V};
check_three(_) ->
    {error, invalid_type}.
%% @private Four cards: a bomb (four of a kind) or three-plus-kicker.
%% count_values/1 returns {Value, Count} pairs sorted by value, so the
%% triple may precede or follow the kicker — both orders are matched.
check_four_or_three_one(Cards) ->
    Values = [V || {V, _} <- Cards],
    case count_values(Values) of
        [{V, 4}] ->
            {ok, ?CARD_TYPE_BOMB, V};
        [{V, 3}, {_, 1}] ->
            {ok, ?CARD_TYPE_THREE_ONE, V};
        [{_, 1}, {V, 3}] ->
            {ok, ?CARD_TYPE_THREE_ONE, V};
        _ ->
            {error, invalid_type}
    end.
%% @private Five cards as three-plus-pair; the reported value is the
%% triple's value, in either count order.
check_three_two(Cards) ->
    Values = [V || {V, _} <- Cards],
    case count_values(Values) of
        [{V, 3}, {_, 2}] -> {ok, ?CARD_TYPE_THREE_TWO, V};
        [{_, 2}, {V, 3}] -> {ok, ?CARD_TYPE_THREE_TWO, V};
        _ -> {error, invalid_type}
    end.
%% @private Hands of five or more cards may be sequence-shaped:
%% straight, consecutive pairs, or plane. Anything shorter is invalid.
check_sequence_types(Cards) ->
    case length(Cards) of
        L when L >= 5 ->
            Values = [V || {V, _} <- Cards],
            cond_check_sequence_types(Values, Cards);
        _ ->
            {error, invalid_type}
    end.
%% @private Try the sequence shapes from most to least specific:
%% straight first, then consecutive pairs, then plane (with or
%% without wings). The reported value is the highest card in the run.
cond_check_sequence_types(Values, Cards) ->
    case is_straight(Values) of
        true ->
            {ok, ?CARD_TYPE_STRAIGHT, lists:max(Values)};
        false ->
            case is_straight_pairs(Values) of
                true ->
                    {ok, ?CARD_TYPE_STRAIGHT_PAIR, lists:max(Values)};
                false ->
                    check_plane_types(Cards)
            end
    end.
%% @private Classify a candidate plane hand. check_plane_pattern/1
%% reports the main triple value and whether extra (wing) cards are
%% present; map that boolean directly onto the two plane types.
check_plane_types(Cards) ->
    Values = [V || {V, _} <- Cards],
    case check_plane_pattern(Values) of
        {ok, MainValue, true} ->
            {ok, ?CARD_TYPE_PLANE_WINGS, MainValue};
        {ok, MainValue, false} ->
            {ok, ?CARD_TYPE_PLANE, MainValue};
        false ->
            {error, invalid_type}
    end.
%% @private A straight: five or more consecutive values, all below 2
%% (2s and jokers may not appear in straights). Duplicates break the
%% consecutive check, so values are effectively distinct.
is_straight(Values) ->
    SortedVals = lists:sort(Values),
    length(SortedVals) >= 5 andalso
    lists:max(SortedVals) < ?CARD_2 andalso
    is_consecutive(SortedVals).
%% @private Consecutive pairs: at least three groups, nothing but
%% pairs (pair count * 2 must equal the hand size), everything below
%% 2, and the pair values consecutive.
is_straight_pairs(Values) ->
    case count_values(Values) of
        Pairs when length(Pairs) >= 3 ->
            PairValues = [V || {V, 2} <- Pairs],
            length(PairValues) * 2 =:= length(Values) andalso
            lists:max(PairValues) < ?CARD_2 andalso
            is_consecutive(lists:sort(PairValues));
        _ ->
            false
    end.
%% @private Plane pattern: two or more consecutive triples. HasWings
%% is true when any extra cards accompany the triples.
%% NOTE(review): the number/shape of wing cards is not validated here
%% (e.g. three triples plus one stray card still passes) — confirm
%% this is intended.
check_plane_pattern(Values) ->
    Counts = count_values(Values),
    ThreeCounts = [{V, C} || {V, C} <- Counts, C =:= 3],
    case length(ThreeCounts) >= 2 of
        true ->
            ThreeValues = [V || {V, _} <- ThreeCounts],
            SortedThrees = lists:sort(ThreeValues),
            case is_consecutive(SortedThrees) of
                true ->
                    MainValue = lists:max(SortedThrees),
                    HasWings = length(Values) > length(ThreeValues) * 3,
                    {ok, MainValue, HasWings};
                false ->
                    false
            end;
        false ->
            false
    end.
%% @private True when the list ascends in steps of exactly one.
%% Empty and singleton lists are trivially consecutive.
is_consecutive([A, B | Rest]) when B - A =:= 1 ->
    is_consecutive([B | Rest]);
is_consecutive([_, _ | _]) ->
    false;
is_consecutive([_]) ->
    true;
is_consecutive([]) ->
    true.
%% @private Tally occurrences of each value, returning {Value, Count}
%% pairs sorted ascending by value — the same shape the original
%% produced by grouping runs of a pre-sorted list.
count_values(Values) ->
    Tally = lists:foldl(
              fun(V, Acc) ->
                      maps:update_with(V, fun(N) -> N + 1 end, 1, Acc)
              end,
              #{},
              Values),
    lists:sort(maps:to_list(Tally)).
%% @private Rank two classified hands.
%% Identical types compare by value; note that equal values
%% deliberately fall through to `lesser' (an equal hand cannot beat
%% the table). Rockets beat everything, bombs beat everything except
%% rockets and bigger bombs, and any other cross-type comparison is
%% invalid. The unused Type1 in the bomb-on-the-right clause is now
%% underscore-prefixed to silence the compiler warning.
compare_types_and_values(Type, Value, Type, Value2) ->
    if
        Value > Value2 -> greater;
        true -> lesser
    end;
compare_types_and_values(?CARD_TYPE_ROCKET, _, _, _) ->
    greater;
compare_types_and_values(_, _, ?CARD_TYPE_ROCKET, _) ->
    lesser;
compare_types_and_values(?CARD_TYPE_BOMB, Value1, Type2, Value2) ->
    case Type2 of
        ?CARD_TYPE_BOMB when Value1 > Value2 -> greater;
        ?CARD_TYPE_BOMB -> lesser;
        _ -> greater
    end;
compare_types_and_values(_Type1, _, ?CARD_TYPE_BOMB, _) ->
    lesser;
compare_types_and_values(_, _, _, _) ->
    invalid.
%% @private Can the first (type, value) beat the second on the table?
%% Same type: strictly higher value wins. A rocket beats everything;
%% a bomb beats any non-rocket (bomb-vs-bomb was already taken by the
%% first clause). Nothing else can beat.
can_beat(Type, Value1, Type, Value2) ->
    Value1 > Value2;
can_beat(?CARD_TYPE_ROCKET, _, _, _) ->
    true;
can_beat(?CARD_TYPE_BOMB, _, Type2, _) when Type2 =/= ?CARD_TYPE_ROCKET ->
    true;
can_beat(_, _, _, _) ->
    false.
%% @private Render one card as its suit symbol followed by its rank.
%% A value with no matching rank clause raises a case_clause error.
format_card({Value, Suit}) ->
    SuitStr = case Suit of
        hearts -> "♥";
        diamonds -> "♦";
        clubs -> "♣";
        spades -> "♠";
        joker -> "J"
    end,
    ValueStr = case Value of
        ?CARD_JOKER_BIG -> "BJ";
        ?CARD_JOKER_SMALL -> "SJ";
        ?CARD_2 -> "2";
        ?CARD_A -> "A";
        ?CARD_K -> "K";
        ?CARD_Q -> "Q";
        ?CARD_J -> "J";
        10 -> "10";
        N when N >= 3, N =< 9 -> integer_to_list(N)
    end,
    SuitStr ++ ValueStr.
%% @private A hand must be a non-empty list.
%% Pattern matching replaces the is_list/length combination: length/1
%% is O(n) just to test emptiness, while [_ | _] answers in constant
%% time. (Downstream duplicate/validity checks still reject malformed
%% contents.)
is_valid_card_list([_ | _]) -> true;
is_valid_card_list(_) -> false.
%% @private No card may appear twice: deduplicating with usort must
%% not shrink the hand.
no_duplicate_cards(Cards) ->
    Unique = lists:usort(Cards),
    length(Unique) =:= length(Cards).
%% @private Every card in the hand must individually be well-formed.
all_cards_valid(Cards) ->
    lists:all(fun is_valid_card/1, Cards).
%% @private A single card is valid when either: its value lies in the
%% normal 3..2 range with one of the four real suits, or it is a joker
%% value carrying the `joker' suit.
is_valid_card({Value, Suit}) ->
    (Value >= ?CARD_3 andalso Value =< ?CARD_2 andalso
     lists:member(Suit, [hearts, diamonds, clubs, spades])) orelse
    (Value >= ?CARD_JOKER_SMALL andalso Value =< ?CARD_JOKER_BIG andalso
     Suit =:= joker).
%% @private Sort cards ascending by value, breaking ties by suit.
%% With integer values (see the card() type), default term ordering on
%% {Value, Suit} tuples compares value first and suit second — exactly
%% the comparator the original spelled out by hand.
sort_cards(Cards) ->
    lists:sort(Cards).
%% @private Extract the numeric value of a card.
get_card_value({Value, _Suit}) -> Value.
@ -1,53 +0,0 @@ | |||
-module(decision_engine). | |||
-export([make_decision/3, evaluate_options/2, calculate_win_probability/2]). | |||
%% Entry point: evaluate every option, estimate win rates and risk,
%% pick the optimum and apply it.
%% NOTE(review): calculate_win_probabilities/2, analyze_risks/2,
%% select_optimal_option/3 and apply_decision/3 are not defined in the
%% visible part of this module — confirm they exist elsewhere.
make_decision(GameState, AIState, Options) ->
    % Deep-evaluate each option
    EvaluatedOptions = deep_evaluate_options(Options, GameState, AIState),
    % Win-rate estimation
    WinProbabilities = calculate_win_probabilities(EvaluatedOptions, GameState),
    % Risk assessment
    RiskAnalysis = analyze_risks(EvaluatedOptions, GameState),
    % Combined decision
    BestOption = select_optimal_option(EvaluatedOptions, WinProbabilities, RiskAnalysis),
    % Apply the final decision
    apply_decision(BestOption, GameState, AIState).
%% Score each option via Monte-Carlo search (1000 playouts), strategy
%% evaluation, and a predicted opponent reaction.
%% NOTE(review): #ai_state is not defined in this file — presumably it
%% comes from a shared header; confirm.
deep_evaluate_options(Options, GameState, AIState) ->
    lists:map(
        fun(Option) ->
            % Deep search evaluation
            SearchResult = monte_carlo_search(Option, GameState, 1000),
            % Strategy evaluation
            StrategyScore = strategy_optimizer:evaluate_strategy(Option, GameState),
            % Predicted opponent response
            OpponentReaction = opponent_modeling:predict_play(AIState#ai_state.opponent_model, GameState),
            % Combined score
            {Option, calculate_comprehensive_score(SearchResult, StrategyScore, OpponentReaction)}
        end,
        Options
    ).
%% Combine situation, pattern and opponent factors into one win
%% probability with fixed 0.4 / 0.3 / 0.3 weights.
calculate_win_probability(Option, GameState) ->
    % Current-situation analysis
    SituationScore = analyze_situation_score(GameState),
    % Card-pattern strength
    PatternScore = analyze_pattern_strength(Option),
    % Opponent-model factors
    OpponentScore = analyze_opponent_factors(GameState),
    % Weighted combination
    calculate_combined_probability([
        {SituationScore, 0.4},
        {PatternScore, 0.3},
        {OpponentScore, 0.3}
    ]).
@ -1,104 +0,0 @@ | |||
-module(deep_learning). | |||
-behaviour(gen_server). | |||
-export([start_link/0, init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). | |||
-export([train_network/2, predict/2, update_network/2, get_network_stats/1]). | |||
-record(state, { | |||
networks = #{}, % Map: NetworkName -> NetworkData | |||
training_queue = [], % 训练队列 | |||
batch_size = 32, % 批次大小 | |||
learning_rate = 0.001 % 学习率 | |||
}). | |||
-record(network, { | |||
layers = [], % 网络层结构 | |||
weights = #{}, % 权重 | |||
biases = #{}, % 偏置 | |||
activation = relu, % 激活函数 | |||
optimizer = adam, % 优化器 | |||
loss_history = [], % 损失历史 | |||
accuracy_history = [] % 准确率历史 | |||
}). | |||
%% API
%% Start the singleton deep-learning server, registered as ?MODULE.
start_link() ->
    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
%% Synchronously train the named network on TrainingData.
train_network(NetworkName, TrainingData) ->
    gen_server:call(?MODULE, {train, NetworkName, TrainingData}).
%% Synchronously run a forward pass on the named network.
predict(NetworkName, Input) ->
    gen_server:call(?MODULE, {predict, NetworkName, Input}).
%% Asynchronously apply precomputed gradients to the named network.
update_network(NetworkName, Gradients) ->
    gen_server:cast(?MODULE, {update, NetworkName, Gradients}).
%% Fetch accumulated statistics for the named network.
get_network_stats(NetworkName) ->
    gen_server:call(?MODULE, {get_stats, NetworkName}).
%% Internal functions
%% Build a fresh #network{} from a list of layer sizes.
%% NOTE(review): initialize_biases/1 is not defined in the visible
%% part of this module — confirm it exists elsewhere.
initialize_network(LayerSizes) ->
    Layers = create_layers(LayerSizes),
    Weights = initialize_weights(Layers),
    Biases = initialize_biases(Layers),
    #network{
        layers = Layers,
        weights = Weights,
        biases = Biases
    }.
%% One dense-layer descriptor per requested size.
create_layers(Sizes) ->
    lists:map(fun(Size) ->
        #{size => Size, type => dense}
    end, Sizes).
%% NOTE(review): weights are keyed by layer *size*, so two layers with
%% the same size overwrite each other, and each matrix is Size x Size
%% rather than spanning consecutive layer dimensions — confirm this is
%% intended.
initialize_weights(Layers) ->
    lists:foldl(
        fun(Layer, Acc) ->
            Size = maps:get(size, Layer),
            W = random_matrix(Size, Size),
            maps:put(Size, W, Acc)
        end,
        #{},
        Layers
    ).
%% Uniform random matrix in [-0.5, 0.5), built via the matrix module.
random_matrix(Rows, Cols) ->
    matrix:new(Rows, Cols, fun() -> rand:uniform() - 0.5 end).
%% Fold the input through every layer: Z = W*A + B, then the network's
%% activation. Weights and biases are looked up by layer size (see
%% initialize_weights/1 and its caveats).
forward_propagation(Input, Network) ->
    #network{layers = Layers, weights = Weights, biases = Biases} = Network,
    lists:foldl(
        fun(Layer, Acc) ->
            W = maps:get(maps:get(size, Layer), Weights),
            B = maps:get(maps:get(size, Layer), Biases),
            Z = matrix:add(matrix:multiply(W, Acc), B),
            activate(Z, Network#network.activation)
        end,
        Input,
        Layers
    ).
%% Backpropagation: compute gradients and loss, update weights and
%% prepend the loss to the history.
%% NOTE(review): calculate_gradients/3 and update_weights/2 are not in
%% the visible part of this module — confirm they exist elsewhere.
backward_propagation(Network, Input, Target) ->
    {Gradients, Loss} = calculate_gradients(Network, Input, Target),
    {Network#network{
        weights = update_weights(Network#network.weights, Gradients),
        loss_history = [Loss | Network#network.loss_history]
    }, Loss}.
%% Element-wise activation; any atom other than relu/sigmoid/tanh
%% raises function_clause.
activate(Z, relu) ->
    matrix:map(fun(X) -> max(0, X) end, Z);
activate(Z, sigmoid) ->
    matrix:map(fun(X) -> 1 / (1 + math:exp(-X)) end, Z);
activate(Z, tanh) ->
    matrix:map(fun(X) -> math:tanh(X) end, Z).
%% Dispatch to the configured optimizer.
%% NOTE(review): update_adam/2 and update_sgd/2 are not in the visible
%% part of this module.
optimize(Network, Gradients, adam) ->
    % Adam optimizer
    update_adam(Network, Gradients);
optimize(Network, Gradients, sgd) ->
    % Stochastic gradient descent
    update_sgd(Network, Gradients).
@ -0,0 +1,277 @@ | |||
-module(doudizhu_ai). | |||
-behaviour(gen_server). | |||
-export([ | |||
start_link/1, | |||
make_decision/2, | |||
analyze_cards/1, | |||
update_game_state/2 | |||
]). | |||
-export([ | |||
init/1, | |||
handle_call/3, | |||
handle_cast/2, | |||
handle_info/2, | |||
terminate/2, | |||
code_change/3 | |||
]). | |||
-include("card_types.hrl"). | |||
-record(state, { | |||
player_id, % AI 玩家ID | |||
role, % dizhu | nongmin (地主或农民) | |||
known_cards = [], % 已知的牌 | |||
hand_cards = [], % 手牌 | |||
played_cards = [], % 已打出的牌 | |||
other_players = [], % 其他玩家信息 | |||
game_history = [], % 游戏历史 | |||
strategy_cache = #{} % 策略缓存 | |||
}). | |||
%% API functions
%% Start the AI server for PlayerId.
%% NOTE(review): registered as {local, ?MODULE}, so only one AI
%% process can run per node even though a PlayerId is supplied —
%% confirm this is intended.
start_link(PlayerId) ->
    gen_server:start_link({local, ?MODULE}, ?MODULE, [PlayerId], []).
%% Ask the AI for its move given the game state and candidate options.
make_decision(GameState, Options) ->
    gen_server:call(?MODULE, {make_decision, GameState, Options}).
%% Analyze a set of cards (hand decomposition).
analyze_cards(Cards) ->
    gen_server:call(?MODULE, {analyze_cards, Cards}).
%% Push a game event into the AI state (fire-and-forget).
update_game_state(Event, Data) ->
    gen_server:cast(?MODULE, {update_game_state, Event, Data}).
%% Callback functions
init([PlayerId]) ->
    {ok, #state{player_id = PlayerId}}.
%% NOTE(review): perform_cards_analysis/2 is not in the visible part
%% of this module — confirm it exists elsewhere.
handle_call({make_decision, GameState, Options}, _From, State) ->
    {Decision, NewState} = calculate_best_move(GameState, Options, State),
    {reply, Decision, NewState};
handle_call({analyze_cards, Cards}, _From, State) ->
    Analysis = perform_cards_analysis(Cards, State),
    {reply, Analysis, State};
handle_call(_Request, _From, State) ->
    {reply, ok, State}.
handle_cast({update_game_state, Event, Data}, State) ->
    NewState = update_state(Event, Data, State),
    {noreply, NewState};
handle_cast(_Msg, State) ->
    {noreply, State}.
%% Drain unexpected messages.
handle_info(_Info, State) ->
    {noreply, State}.
terminate(_Reason, _State) ->
    ok.
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.
%% Internal functions
%% Decision pipeline: analyze -> generate -> score -> select -> cache.
calculate_best_move(GameState, Options, State) ->
    % 1. Analyze the current situation
    Situation = analyze_situation(GameState, State),
    % 2. Generate candidate moves
    PossibleMoves = generate_possible_moves(GameState, Options, State),
    % 3. Score each move
    ScoredMoves = evaluate_moves(PossibleMoves, Situation, State),
    % 4. Pick the best-scoring move
    {BestMove, Score} = select_best_move(ScoredMoves),
    % 5. Record the choice in the strategy cache
    NewState = update_strategy_state(BestMove, Score, State),
    {BestMove, NewState}.
%% Summarise the current situation as a #situation record.
%% NOTE(review): #situation is not defined in this file — it must come
%% from card_types.hrl or another shared header; confirm.
analyze_situation(GameState, State) ->
    #situation{
        game_stage = determine_game_stage(GameState),
        hand_strength = evaluate_hand_strength(State#state.hand_cards),
        control_level = calculate_control_level(GameState, State),
        winning_probability = estimate_winning_probability(GameState, State)
    }.
%% Classify the game stage by the number of cards still in play.
%% The original used Elixir's `cond do ... end', which is not Erlang
%% syntax and cannot compile; an `if' expression with ordered guards
%% is the Erlang equivalent.
determine_game_stage(GameState) ->
    CardsLeft = count_remaining_cards(GameState),
    if
        CardsLeft > 15 -> early_game;
        CardsLeft > 8 -> mid_game;
        true -> end_game
    end.
%% Score the hand as the sum of base, combo and control components.
%% NOTE(review): #hand_strength and #components are not defined in
%% this file — presumably shared via a header; confirm.
evaluate_hand_strength(Cards) ->
    % Decompose the hand
    Components = analyze_card_components(Cards),
    % Base score
    BaseScore = calculate_base_score(Components),
    % Combination value
    ComboScore = calculate_combo_value(Components),
    % Control-card value
    ControlScore = calculate_control_value(Components),
    % Aggregate
    #hand_strength{
        base_score = BaseScore,
        combo_score = ComboScore,
        control_score = ControlScore,
        total_score = BaseScore + ComboScore + ControlScore
    }.
%% Decompose a hand into singles / pairs / triples / bombs / sequences.
%% NOTE(review): group_cards_by_value/1 and the find_* helpers are not
%% defined in the visible part of this module — confirm they exist.
analyze_card_components(Cards) ->
    % Group by value
    Groups = group_cards_by_value(Cards),
    % Identify shapes
    Singles = find_singles(Groups),
    Pairs = find_pairs(Groups),
    Triples = find_triples(Groups),
    Bombs = find_bombs(Groups),
    Sequences = find_sequences(Groups),
    % Assemble the decomposition record
    #components{
        singles = Singles,
        pairs = Pairs,
        triples = Triples,
        bombs = Bombs,
        sequences = Sequences
    }.
%% Enumerate candidate plays. When leading (no last play) every shape
%% is generated; otherwise only responses to the last play's type and
%% value. Options is currently unused — underscore-prefixed to silence
%% the unused-variable warning while keeping the arity.
generate_possible_moves(GameState, _Options, State) ->
    LastPlay = get_last_play(GameState),
    HandCards = State#state.hand_cards,
    case LastPlay of
        none ->
            generate_leading_moves(HandCards);
        {Type, Value, _Cards} ->
            generate_following_moves(HandCards, Type, Value)
    end.
%% Pair every candidate move with its computed score.
evaluate_moves(Moves, Situation, State) ->
    [{Move, calculate_move_score(Move, Situation, State)} || Move <- Moves].
%% Weighted score of one move: base 40%, position 20%, strategy 30%,
%% risk 10%.
calculate_move_score(Move, Situation, State) ->
    BaseScore = calculate_base_move_score(Move),
    PositionScore = calculate_position_score(Move, Situation),
    StrategyScore = calculate_strategy_score(Move, Situation, State),
    RiskScore = calculate_risk_score(Move, Situation, State),
    BaseScore * 0.4 +
    PositionScore * 0.2 +
    StrategyScore * 0.3 +
    RiskScore * 0.1.
%% Pick the highest-scoring move. Ties keep the earliest candidate,
%% and when every score is =< 0 (or the list is empty) the result is
%% {pass, 0}.
select_best_move(ScoredMoves) ->
    TakeBetter =
        fun({Move, Score}, {_BestMove, BestScore}) when Score > BestScore ->
                {Move, Score};
           (_Candidate, Best) ->
                Best
        end,
    lists:foldl(TakeBetter, {pass, 0}, ScoredMoves).
%% Record the chosen move and its score in the strategy cache.
%% NOTE(review): update_strategy_cache/3 is not in the visible code.
update_strategy_state(Move, Score, State) ->
    NewCache = update_strategy_cache(Move, Score, State#state.strategy_cache),
    State#state{strategy_cache = NewCache}.
%% Leading: enumerate every playable shape from the decomposition.
generate_leading_moves(Cards) ->
    % Decompose the hand
    Components = analyze_card_components(Cards),
    % Candidate plays per shape
    Singles = generate_single_moves(Components),
    Pairs = generate_pair_moves(Components),
    Triples = generate_triple_moves(Components),
    Sequences = generate_sequence_moves(Components),
    Bombs = generate_bomb_moves(Components),
    % All candidates
    Singles ++ Pairs ++ Triples ++ Sequences ++ Bombs.
%% Following: same-type moves above MinValue, plus bomb/rocket options.
%% NOTE(review): find_valid_moves/3 and find_special_moves/1 are not
%% in the visible code.
generate_following_moves(Cards, Type, MinValue) ->
    % Same-type responses beating MinValue
    ValidMoves = find_valid_moves(Cards, Type, MinValue),
    % Bombs and the rocket are always candidates
    SpecialMoves = find_special_moves(Cards),
    ValidMoves ++ SpecialMoves.
%% Estimate the win probability from hand strength, seat position,
%% the distribution of unseen cards and board control; the result is
%% clamped to [0.0, 1.0].
estimate_winning_probability(GameState, State) ->
    HandStrength = evaluate_hand_strength(State#state.hand_cards),
    Position = evaluate_position(GameState),
    RemainingCards = analyze_remaining_cards(GameState, State),
    ControlFactor = calculate_control_factor(GameState, State),
    BaseProb = calculate_base_probability(HandStrength, Position),
    AdjustedProb = adjust_probability(BaseProb, RemainingCards, ControlFactor),
    clamp(AdjustedProb, 0.0, 1.0).
%% Board-control estimate: ratio of control cards in hand to the cards
%% still in play, scaled by a positional bonus.
calculate_control_factor(GameState, State) ->
    ControlCards = count_control_cards(State#state.hand_cards),
    % Result currently unused; the call is kept (underscore-bound) to
    % preserve any side effects of count_total_cards/1 while silencing
    % the unused-variable warning the original produced.
    _TotalCards = count_total_cards(GameState),
    RemainingCards = count_remaining_cards(GameState),
    ControlRatio = ControlCards / max(1, RemainingCards),
    PositionBonus = calculate_position_bonus(GameState, State),
    ControlRatio * PositionBonus.
%% Deduce the unseen cards — full deck minus played and known cards —
%% and analyze their distribution.
%% NOTE(review): get_played_cards/1, generate_full_deck/0 and
%% analyze_card_distribution/1 are not in the visible code.
analyze_remaining_cards(GameState, State) ->
    PlayedCards = get_played_cards(GameState),
    KnownCards = State#state.known_cards,
    AllCards = generate_full_deck(),
    RemainingCards = AllCards -- (PlayedCards ++ KnownCards),
    analyze_card_distribution(RemainingCards).
%% Route a game event to its state-update handler; unknown events
%% leave the state untouched.
update_state(play_cards, Data, State) ->
    update_after_play(Data, State);
update_state(receive_cards, Data, State) ->
    update_after_receive(Data, State);
update_state(game_over, Data, State) ->
    update_after_game_over(Data, State);
update_state(_Event, _Data, State) ->
    State.
%% Utility functions
%% Confine Value to the inclusive range [Lower, Upper].
clamp(Value, Lower, Upper) ->
    min(Upper, max(Lower, Value)).
%% Base probability: normalised hand score times a seat modifier.
%% A Position other than first/middle/last raises case_clause.
calculate_base_probability(HandStrength, Position) ->
    BaseProb = HandStrength#hand_strength.total_score / 100,
    PositionMod = case Position of
        first -> 1.2;
        middle -> 1.0;
        last -> 0.8
    end,
    BaseProb * PositionMod.
%% Scale the base probability by remaining-cards and control modifiers.
%% NOTE(review): calculate_remaining_modifier/1 and
%% calculate_control_modifier/1 are not in the visible code.
adjust_probability(BaseProb, RemainingCards, ControlFactor) ->
    RemainingMod = calculate_remaining_modifier(RemainingCards),
    ControlMod = calculate_control_modifier(ControlFactor),
    BaseProb * RemainingMod * ControlMod.
@ -0,0 +1,399 @@ | |||
-module(doudizhu_ai_strategy). | |||
-export([ | |||
evaluate_situation/2, | |||
choose_strategy/2, | |||
execute_strategy/3, | |||
calculate_move_score/3, | |||
analyze_hand_value/1 | |||
]). | |||
-include("card_types.hrl"). | |||
-record(strategy_context, { | |||
game_stage, % early_game | mid_game | end_game | |||
role, % dizhu | nongmin | |||
hand_strength, % 手牌强度评估 | |||
control_level, % 控制能力评估 | |||
winning_prob, % 胜率估计 | |||
cards_remaining, % 剩余手牌数 | |||
opponent_cards % 对手剩余牌数 | |||
}). | |||
-record(hand_value, { | |||
singles = [], % 单牌 | |||
pairs = [], % 对子 | |||
triples = [], % 三张 | |||
sequences = [], % 顺子 | |||
bombs = [], % 炸弹 | |||
rockets = [] % 火箭 | |||
}). | |||
%% Build a #strategy_context snapshot of the current situation.
%% The original bound the record to `Context' and never used the
%% variable (the record itself was the return value), producing an
%% unused-variable warning; the binding is dropped.
%% NOTE(review): #state is defined in doudizhu_ai, not in this module —
%% it must come from a shared header to compile; confirm.
evaluate_situation(GameState, PlayerState) ->
    HandCards = PlayerState#state.hand_cards,
    OpponentCards = get_opponent_cards(GameState),
    #strategy_context{
        game_stage = determine_game_stage(GameState),
        role = PlayerState#state.role,
        hand_strength = evaluate_hand_strength(HandCards),
        control_level = calculate_control_level(GameState, PlayerState),
        winning_prob = estimate_winning_probability(GameState, PlayerState),
        cards_remaining = length(HandCards),
        opponent_cards = OpponentCards
    }.
%% Strategy selection: base choice, refined by control level, then
%% adjusted for the endgame.
choose_strategy(Context, Options) ->
    BaseStrategy = choose_base_strategy(Context),
    RefineStrategy = refine_strategy(BaseStrategy, Context, Options),
    adjust_strategy_for_endgame(RefineStrategy, Context).
%% Execute the chosen strategy: resolve the handler for the strategy
%% atom, then apply it to the hand, the last play and the game state.
execute_strategy(Strategy, GameState, PlayerState) ->
    HandCards = PlayerState#state.hand_cards,
    LastPlay = get_last_play(GameState),
    Handler = strategy_handler(Strategy),
    Handler(HandCards, LastPlay, GameState).

%% Map a strategy atom to its implementation.
strategy_handler(aggressive) -> fun execute_aggressive_strategy/3;
strategy_handler(conservative) -> fun execute_conservative_strategy/3;
strategy_handler(control) -> fun execute_control_strategy/3;
strategy_handler(explosive) -> fun execute_explosive_strategy/3.
%% Internal functions
%% Base strategy from game stage and role. Hand strength, win
%% probability and cards remaining gate the specific choices;
%% conservative is the fallback.
%% NOTE(review): hand_strength is compared as a number here, while
%% doudizhu_ai's evaluate_hand_strength returns a record — confirm
%% this module's own evaluate_hand_strength/1 returns a scalar.
choose_base_strategy(Context) ->
    case {Context#strategy_context.game_stage, Context#strategy_context.role} of
        {early_game, dizhu} when Context#strategy_context.hand_strength >= 0.7 ->
            aggressive;
        {early_game, nongmin} when Context#strategy_context.hand_strength >= 0.8 ->
            control;
        {mid_game, _} when Context#strategy_context.winning_prob >= 0.7 ->
            control;
        {end_game, _} when Context#strategy_context.cards_remaining =< 5 ->
            explosive;
        _ ->
            conservative
    end.
%% Refine the base strategy against the current control level: strong
%% control upgrades conservative to control; weak control downgrades
%% aggressive to conservative. Options is currently unused —
%% underscore-prefixed to silence the warning while keeping the arity.
refine_strategy(BaseStrategy, Context, _Options) ->
    case {BaseStrategy, Context#strategy_context.control_level} of
        {conservative, Control} when Control >= 0.8 ->
            control;
        {aggressive, Control} when Control =< 0.3 ->
            conservative;
        {Strategy, _} ->
            Strategy
    end.
%% Endgame override: with three or fewer cards left, push hard when
%% clearly winning (>= 0.8) and gamble with explosive play when
%% clearly losing (=< 0.2); otherwise keep the refined strategy.
adjust_strategy_for_endgame(Strategy, Context) ->
    case Context#strategy_context.cards_remaining of
        N when N =< 3 ->
            case Context#strategy_context.winning_prob of
                P when P >= 0.8 -> aggressive;
                P when P =< 0.2 -> explosive;
                _ -> Strategy
            end;
        _ ->
            Strategy
    end.
%% Aggressive strategy: lead or follow with pressure plays.
%% GameState is unused in all four execute_*_strategy helpers (each
%% only needs the hand analysis and the last play); the parameter is
%% underscore-prefixed to silence the warnings while preserving the
%% arity expected by execute_strategy/3.
execute_aggressive_strategy(HandCards, LastPlay, _GameState) ->
    HandValue = analyze_hand_value(HandCards),
    case LastPlay of
        none ->
            choose_aggressive_lead(HandValue);
        Play ->
            choose_aggressive_follow(HandValue, Play)
    end.
%% Conservative strategy: minimise risk, pass freely.
execute_conservative_strategy(HandCards, LastPlay, _GameState) ->
    HandValue = analyze_hand_value(HandCards),
    case LastPlay of
        none ->
            choose_conservative_lead(HandValue);
        Play ->
            choose_conservative_follow(HandValue, Play)
    end.
%% Control strategy: keep the initiative.
execute_control_strategy(HandCards, LastPlay, _GameState) ->
    HandValue = analyze_hand_value(HandCards),
    case LastPlay of
        none ->
            choose_control_lead(HandValue);
        Play ->
            choose_control_follow(HandValue, Play)
    end.
%% Explosive strategy: spend the strongest combinations now.
execute_explosive_strategy(HandCards, LastPlay, _GameState) ->
    HandValue = analyze_hand_value(HandCards),
    case LastPlay of
        none ->
            choose_explosive_lead(HandValue);
        Play ->
            choose_explosive_follow(HandValue, Play)
    end.
%% Aggressive lead: best combination first, else any single.
%% NOTE(review): the find_* helpers used by these choosers are not
%% defined in the visible part of this module — confirm they exist.
choose_aggressive_lead(HandValue) ->
    case find_best_lead_combination(HandValue) of
        {ok, Move} -> Move;
        none -> find_single_card(HandValue)
    end.
%% Aggressive follow: the smallest winning combination, else a bomb
%% when judged worthwhile, else pass.
choose_aggressive_follow(HandValue, LastPlay) ->
    case find_minimum_greater_combination(HandValue, LastPlay) of
        {ok, Move} -> Move;
        none ->
            case should_use_bomb(HandValue, LastPlay) of
                true -> find_smallest_bomb(HandValue);
                false -> pass
            end
    end.
%% Conservative lead: a safe lead, else the smallest single.
choose_conservative_lead(HandValue) ->
    case find_safe_lead(HandValue) of
        {ok, Move} -> Move;
        none -> find_smallest_single(HandValue)
    end.
%% Conservative follow: only safe answers, otherwise pass.
choose_conservative_follow(HandValue, LastPlay) ->
    case find_safe_follow(HandValue, LastPlay) of
        {ok, Move} -> Move;
        none -> pass
    end.
%% Control lead: a control combination, else a tempo play.
choose_control_lead(HandValue) ->
    case find_control_combination(HandValue) of
        {ok, Move} -> Move;
        none -> find_tempo_play(HandValue)
    end.
%% Control follow: only answer when keeping control is worthwhile.
%% The original returned find_minimum_greater_combination/2's result
%% ({ok, Move} | none) unwrapped, while every sibling choose_*_follow
%% returns a bare move or `pass'; unwrap here for consistency.
choose_control_follow(HandValue, LastPlay) ->
    case should_maintain_control(HandValue, LastPlay) of
        true ->
            case find_minimum_greater_combination(HandValue, LastPlay) of
                {ok, Move} -> Move;
                none -> pass
            end;
        false -> pass
    end.
%% Explosive lead: the strongest combination, else anything playable.
choose_explosive_lead(HandValue) ->
    case find_strongest_combination(HandValue) of
        {ok, Move} -> Move;
        none -> find_any_playable(HandValue)
    end.
%% Explosive follow: a crushing play if one exists, otherwise pass.
choose_explosive_follow(HandValue, LastPlay) ->
    case find_crushing_play(HandValue, LastPlay) of
        {ok, Move} -> Move;
        none -> pass
    end.
%% Decompose a hand into a #hand_value record of playable shapes.
analyze_hand_value(Cards) ->
    GroupedCards = group_cards(Cards),
    #hand_value{
        singles = find_singles(GroupedCards),
        pairs = find_pairs(GroupedCards),
        triples = find_triples(GroupedCards),
        sequences = find_sequences(GroupedCards),
        bombs = find_bombs(GroupedCards),
        rockets = find_rockets(GroupedCards)
    }.
%% Weighted score of a move: base 30%, tempo 20%, control 30%,
%% efficiency 20%, then a context (role/position) adjustment.
calculate_move_score(Move, Context, LastPlay) ->
    BaseScore = calculate_base_score(Move),
    TempoScore = calculate_tempo_score(Move, Context),
    ControlScore = calculate_control_score(Move, Context),
    EfficiencyScore = calculate_efficiency_score(Move, Context),
    WeightedScore = BaseScore * 0.3 +
                   TempoScore * 0.2 +
                   ControlScore * 0.3 +
                   EfficiencyScore * 0.2,
    adjust_score_for_context(WeightedScore, Move, Context, LastPlay).
%% Helper functions
%% Bucket cards by value into a map of Value => [Card]; cards within a
%% bucket accumulate in reverse insertion order.
group_cards(Cards) ->
    lists:foldl(
      fun({Value, _} = Card, Buckets) ->
              Update = fun(Existing) -> [Card | Existing] end,
              maps:update_with(Value, Update, [Card], Buckets)
      end,
      #{},
      Cards).
%% Values backed by exactly one card, as {Value, Card} pairs; the
%% second fun clause skips every group of two or more cards.
find_singles(GroupedCards) ->
    maps:fold(fun(Value, [Card], Acc) ->
        [{Value, Card}|Acc];
        (_, _, Acc) -> Acc
    end, [], GroupedCards).
%% Every value with at least two cards, keeping the first two cards of
%% each group as the pair.
find_pairs(GroupedCards) ->
    maps:fold(
      fun(Value, [C1, C2 | _], Acc) -> [{Value, [C1, C2]} | Acc];
         (_Value, _Cards, Acc) -> Acc
      end,
      [],
      GroupedCards).
%% Every value with at least three cards, keeping the first three of
%% each group as the triple.
find_triples(GroupedCards) ->
    maps:fold(
      fun(Value, [C1, C2, C3 | _], Acc) -> [{Value, [C1, C2, C3]} | Acc];
         (_Value, _Cards, Acc) -> Acc
      end,
      [],
      GroupedCards).
%% Straights: runs of at least five consecutive values, one card each.
find_sequences(GroupedCards) ->
    Values = lists:sort(maps:keys(GroupedCards)),
    find_consecutive_sequences(Values, GroupedCards, 5).
%% Every value with four or more cards; the full card list of the
%% group is kept, matching the original behaviour.
find_bombs(GroupedCards) ->
    maps:fold(
      fun(Value, [_, _, _, _ | _] = Cards, Acc) -> [{Value, Cards} | Acc];
         (_Value, _Cards, Acc) -> Acc
      end,
      [],
      GroupedCards).
%% The rocket: both jokers present; keyed by the big joker's value.
find_rockets(GroupedCards) ->
    case {maps:get(?CARD_JOKER_SMALL, GroupedCards, []),
          maps:get(?CARD_JOKER_BIG, GroupedCards, [])} of
        {[Small], [Big]} -> [{?CARD_JOKER_BIG, [Small, Big]}];
        _ -> []
    end.
%% Entry point for run detection over the sorted value list.
find_consecutive_sequences(Values, GroupedCards, MinLength) ->
    find_sequences_of_length(Values, GroupedCards, MinLength, []).
%% Try a run starting at each value. Suffixes of a longer run are
%% reported as separate (overlapping) sequences.
find_sequences_of_length([], _, _, Acc) -> Acc;
find_sequences_of_length([V|Rest], GroupedCards, MinLength, Acc) ->
    case find_sequence_starting_at(V, Rest, GroupedCards, MinLength) of
        {ok, Sequence} ->
            find_sequences_of_length(Rest, GroupedCards, MinLength,
                                   [Sequence|Acc]);
        false ->
            find_sequences_of_length(Rest, GroupedCards, MinLength, Acc)
    end.
%% A run beginning at Start qualifies once it spans MinLength cards.
find_sequence_starting_at(Start, Rest, GroupedCards, MinLength) ->
    case collect_sequence(Start, Rest, GroupedCards, []) of
        Seq when length(Seq) >= MinLength ->
            {ok, {Start, Seq}};
        _ ->
            false
    end.
%% Walk consecutive values, taking one card per value.
%% NOTE(review): the `_ ->' fallbacks only fire if a value maps to an
%% empty list; maps:get/2 raises badkey for absent keys — confirm
%% empty groups can actually occur here.
collect_sequence(Current, [Next|Rest], GroupedCards, Acc) when Next =:= Current + 1 ->
    case maps:get(Current, GroupedCards) of
        Cards when length(Cards) >= 1 ->
            collect_sequence(Next, Rest, GroupedCards, [hd(Cards)|Acc]);
        _ ->
            lists:reverse(Acc)
    end;
collect_sequence(Current, _, GroupedCards, Acc) ->
    case maps:get(Current, GroupedCards) of
        Cards when length(Cards) >= 1 ->
            lists:reverse([hd(Cards)|Acc]);
        _ ->
            lists:reverse(Acc)
    end.
%% Base score: card value scaled by a per-type multiplier (rocket
%% highest, single lowest).
%% NOTE(review): ?CARD_TYPE_PLANE_WINGS has no clause here, so a
%% plane-with-wings move raises case_clause — confirm and add it.
calculate_base_score({Type, Value, _Cards}) ->
    BaseScore = Value * 10,
    TypeMultiplier = case Type of
        ?CARD_TYPE_ROCKET -> 100;
        ?CARD_TYPE_BOMB -> 80;
        ?CARD_TYPE_STRAIGHT -> 40;
        ?CARD_TYPE_STRAIGHT_PAIR -> 35;
        ?CARD_TYPE_PLANE -> 30;
        ?CARD_TYPE_THREE_TWO -> 25;
        ?CARD_TYPE_THREE_ONE -> 20;
        ?CARD_TYPE_THREE -> 15;
        ?CARD_TYPE_PAIR -> 10;
        ?CARD_TYPE_SINGLE -> 5
    end,
    BaseScore * TypeMultiplier / 100.
%% Tempo depends on the game stage.
calculate_tempo_score(Move, Context) ->
    case Context#strategy_context.game_stage of
        early_game -> calculate_early_tempo(Move, Context);
        mid_game -> calculate_mid_tempo(Move, Context);
        end_game -> calculate_end_tempo(Move, Context)
    end.
%% Control contribution: bombs/rockets count fully, 2s-and-above at
%% 0.8, everything else at 0.5 — scaled by the current control level.
calculate_control_score(Move, Context) ->
    BaseControl = case Move of
        {Type, _, _} when Type =:= ?CARD_TYPE_BOMB;
                         Type =:= ?CARD_TYPE_ROCKET -> 1.0;
        {_, Value, _} when Value >= ?CARD_2 -> 0.8;
        _ -> 0.5
    end,
    BaseControl * Context#strategy_context.control_level.
%% Efficiency: the fraction of the hand this move sheds, boosted as
%% the hand shrinks. The Type element of the move tuple is unused —
%% underscore-prefixed to silence the compiler warning.
calculate_efficiency_score(Move, Context) ->
    {_Type, _, Cards} = Move,
    CardsUsed = length(Cards),
    RemainingCards = Context#strategy_context.cards_remaining - CardsUsed,
    Efficiency = CardsUsed / max(1, Context#strategy_context.cards_remaining),
    Efficiency * (1 + (20 - RemainingCards) / 20).
%% Role/position adjustment: the dizhu gets a 20% bonus when leading
%% and 10% when following; farmers are unadjusted. Move is unused —
%% underscore-prefixed to silence the compiler warning while keeping
%% the arity callers expect.
adjust_score_for_context(Score, _Move, Context, LastPlay) ->
    case {Context#strategy_context.role, LastPlay} of
        {dizhu, none} -> Score * 1.2;
        {dizhu, _} -> Score * 1.1;
        {nongmin, _} -> Score
    end.
%% Early game favours cheap probing shapes (singles/pairs) and long
%% runs. Context is accepted for interface symmetry but unused in all
%% three *_tempo helpers; underscore-prefixing removes the
%% unused-variable warnings the originals produced.
calculate_early_tempo(Move, _Context) ->
    case Move of
        {Type, _, _} when Type =:= ?CARD_TYPE_SINGLE;
                         Type =:= ?CARD_TYPE_PAIR ->
            0.7;
        {Type, _, _} when Type =:= ?CARD_TYPE_STRAIGHT;
                         Type =:= ?CARD_TYPE_STRAIGHT_PAIR ->
            0.9;
        _ ->
            0.5
    end.
%% Mid game favours three-with-kicker shapes and planes.
calculate_mid_tempo(Move, _Context) ->
    case Move of
        {Type, _, _} when Type =:= ?CARD_TYPE_THREE_ONE;
                         Type =:= ?CARD_TYPE_THREE_TWO ->
            0.8;
        {Type, _, _} when Type =:= ?CARD_TYPE_PLANE ->
            0.9;
        _ ->
            0.6
    end.
%% End game favours finishers: bombs, rockets and high cards.
calculate_end_tempo(Move, _Context) ->
    case Move of
        {Type, _, _} when Type =:= ?CARD_TYPE_BOMB;
                         Type =:= ?CARD_TYPE_ROCKET ->
            1.0;
        {_, Value, _} when Value >= ?CARD_2 ->
            0.9;
        _ ->
            0.7
    end.
@ -0,0 +1,44 @@ | |||
-module(doudizhu_ai_sup). | |||
-behaviour(supervisor). | |||
-export([start_link/0]). | |||
-export([init/1]). | |||
%% Start the AI supervisor, registered locally under the module name.
start_link() ->
    supervisor:start_link({local, ?MODULE}, ?MODULE, []).
%% Supervisor callback: one_for_one strategy (max 10 restarts / 60s) over the
%% three permanent worker children ml_engine, training_system, visualization.
init([]) ->
    Flags = #{
        strategy => one_for_one,
        intensity => 10,
        period => 60
    },
    Ids = [ml_engine, training_system, visualization],
    {ok, {Flags, [worker_spec(Id) || Id <- Ids]}}.

%% Build a permanent worker child spec whose id, start module and callback
%% module are all the same name.
worker_spec(Module) ->
    #{
        id => Module,
        start => {Module, start_link, []},
        restart => permanent,
        shutdown => 5000,
        type => worker,
        modules => [Module]
    }.
@ -1,112 +0,0 @@ | |||
-module(matrix). | |||
-export([new/3, multiply/2, add/2, subtract/2, transpose/1, map/2]). | |||
-export([from_list/1, to_list/1, get/3, set/4, shape/1]). | |||
-record(matrix, { | |||
rows, | |||
cols, | |||
data | |||
}). | |||
%% Create a Rows x Cols matrix. InitFun may be a zero-arity generator fun
%% (called once per cell) or a constant initial value for every cell.
%% Fix: the constant branch called array:set_value/2, which does not exist in
%% the stdlib array module — a constant is now installed as the array default.
new(Rows, Cols, InitFun) when is_integer(Rows), is_integer(Cols), Rows > 0, Cols > 0 ->
    Size = Rows * Cols,
    Data =
        case is_function(InitFun) of
            true ->
                %% Fill every cell from the generator; a single fold over the
                %% flat index replaces the original nested row/column folds.
                lists:foldl(
                    fun(Idx, Acc) -> array:set(Idx, InitFun(), Acc) end,
                    array:new(Size, {default, 0.0}),
                    lists:seq(0, Size - 1));
            false ->
                %% Constant value: every unset cell reads back as InitFun.
                array:new(Size, {default, InitFun})
        end,
    #matrix{rows = Rows, cols = Cols, data = Data}.
%% Dense matrix product: (M x N) * (N x P) -> (M x P). The shared inner
%% dimension N is enforced by binding it in both argument patterns.
multiply(#matrix{rows = M, cols = N, data = DataA},
         #matrix{rows = N, cols = P, data = DataB}) ->
    %% Row-major walk of the result: for each (Row, Col) accumulate the dot
    %% product of row Row of A with column Col of B.
    Products =
        [lists:sum([array:get(Row * N + K, DataA) * array:get(K * P + Col, DataB)
                    || K <- lists:seq(0, N - 1)])
         || Row <- lists:seq(0, M - 1), Col <- lists:seq(0, P - 1)],
    #matrix{rows = M, cols = P, data = array:fix(array:from_list(Products))}.
%% Element-wise sum of two same-shape matrices (shape equality enforced by
%% binding R and C in both patterns).
add(#matrix{rows = R, cols = C, data = A} = Left,
    #matrix{rows = R, cols = C, data = B}) ->
    Left#matrix{data = array:map(fun(Idx, X) -> X + array:get(Idx, B) end, A)}.

%% Element-wise difference (Left - Right) of two same-shape matrices.
subtract(#matrix{rows = R, cols = C, data = A} = Left,
         #matrix{rows = R, cols = C, data = B}) ->
    Left#matrix{data = array:map(fun(Idx, X) -> X - array:get(Idx, B) end, A)}.
%% Transpose an R x C matrix into a C x R matrix.
transpose(#matrix{rows = R, cols = C, data = Data}) ->
    %% Emit the source column-by-column: element (I, J) of the source lands at
    %% flat index J * R + I of the result, which is exactly the order this
    %% comprehension produces (J outer, I inner).
    Flipped = [array:get(I * C + J, Data)
               || J <- lists:seq(0, C - 1), I <- lists:seq(0, R - 1)],
    #matrix{rows = C, cols = R, data = array:fix(array:from_list(Flipped))}.
%% Apply Fun to every element, preserving the matrix shape.
map(Fun, #matrix{rows = R, cols = C, data = Data}) ->
    NewData = array:map(fun(_, V) -> Fun(V) end, Data),
    #matrix{rows = R, cols = C, data = NewData}.
%% Build a matrix from a list of equal-length rows.
%% Fix: the original crashed on an empty list (hd/1 of []); an empty input now
%% yields a 0 x 0 matrix.
from_list([]) ->
    #matrix{rows = 0, cols = 0, data = array:new(0, {default, 0.0})};
from_list(List) when is_list(List) ->
    Rows = length(List),
    Cols = length(hd(List)),
    Data = array:from_list(lists:flatten(List)),
    #matrix{rows = Rows, cols = Cols, data = Data}.
%% Convert back to a list of rows (inverse of from_list/1 for well-formed input).
to_list(#matrix{rows = R, cols = C, data = Data}) ->
    RowAt = fun(I) -> [array:get(I * C + J, Data) || J <- lists:seq(0, C - 1)] end,
    lists:map(RowAt, lists:seq(0, R - 1)).
%% Fetch the element at (Row, Col); zero-based indices, row-major layout.
get(#matrix{cols = C, data = Data}, Row, Col) ->
    array:get(Row * C + Col, Data).
%% Return a copy of the matrix with (Row, Col) set to Value (zero-based).
set(#matrix{cols = C, data = Data} = M, Row, Col, Value) ->
    NewData = array:set(Row * C + Col, Value, Data),
    M#matrix{data = NewData}.
%% Dimensions as {Rows, Cols}.
shape(#matrix{rows = R, cols = C}) ->
    {R, C}.
@ -1,157 +1,663 @@ | |||
-module(ml_engine). | |||
-behaviour(gen_server). | |||
-export([start_link/0, init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). | |||
-export([train/2, predict/2, update_model/3, get_model_stats/1]). | |||
%% API exports | |||
-export([ | |||
start_link/0, | |||
train/2, | |||
predict/2, | |||
update_model/2, | |||
get_model_state/0, | |||
save_model/1, | |||
load_model/1, | |||
add_training_sample/1 | |||
]). | |||
%% gen_server callbacks | |||
-export([ | |||
init/1, | |||
handle_call/3, | |||
handle_cast/2, | |||
handle_info/2, | |||
terminate/2, | |||
code_change/3 | |||
]). | |||
-include("card_types.hrl"). | |||
%% NOTE(review): this declaration region was a garbled merge of two file
%% versions — the old map-of-models #state fields ran straight into the new
%% single-model fields with no separator, and #model_data was never closed.
%% Reconstructed as syntactically valid declarations carrying the union of
%% both versions' fields so every surviving clause still compiles.
-record(state, {
    model,                        % current model (#model{})
    model_version = 1,            % model version counter
    training_data = [],           % accumulated training samples
    hyperparameters = #{},        % training hyper-parameters
    feature_config = #{},         % feature weighting configuration
    last_update,                  % timestamp of last model update
    performance_metrics = [],     % history of metric maps
    models = #{},                 % legacy: ModelName -> #model_data{}
    model_stats = #{}             % legacy: ModelName -> Stats
}).
-record(model_data, {
    weights = #{},                % legacy per-model weights
    features = [],                % legacy feature list
    learning_rate = 0.01,         % legacy learning rate
    iterations = 0,               % training iteration count
    accuracy = 0.0                % model accuracy
}).
-record(model, {
    weights = #{},                % layer name -> weight matrix
    biases = #{},                 % layer name -> bias vector
    layers = [],                  % network layer configuration
    activation_functions = #{},   % activation name -> fun
    normalization_params = #{},   % feature normalization parameters
    feature_importance = #{},     % feature importance scores
    last_train_error = 0.0        % most recent training error
}).
%% API | |||
%% API 函数 | |||
start_link() -> | |||
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). | |||
train(ModelName, TrainingData) -> | |||
gen_server:call(?MODULE, {train, ModelName, TrainingData}). | |||
train(Data, Options) -> | |||
gen_server:call(?MODULE, {train, Data, Options}, infinity). | |||
predict(Features, Options) -> | |||
gen_server:call(?MODULE, {predict, Features, Options}). | |||
predict(ModelName, Features) -> | |||
gen_server:call(?MODULE, {predict, ModelName, Features}). | |||
update_model(NewModel, Options) -> | |||
gen_server:cast(?MODULE, {update_model, NewModel, Options}). | |||
update_model(ModelName, NewData, Reward) -> | |||
gen_server:cast(?MODULE, {update_model, ModelName, NewData, Reward}). | |||
get_model_state() -> | |||
gen_server:call(?MODULE, get_model_state). | |||
get_model_stats(ModelName) -> | |||
gen_server:call(?MODULE, {get_stats, ModelName}). | |||
save_model(Filename) -> | |||
gen_server:call(?MODULE, {save_model, Filename}). | |||
%% Callbacks | |||
load_model(Filename) -> | |||
gen_server:call(?MODULE, {load_model, Filename}). | |||
add_training_sample(Sample) -> | |||
gen_server:cast(?MODULE, {add_training_sample, Sample}). | |||
%% Callback 函数 | |||
init([]) -> | |||
% 初始化各种策略模型 | |||
Models = initialize_models(), | |||
{ok, #state{models = Models, last_update = os:timestamp()}}. | |||
handle_call({train, ModelName, TrainingData}, _From, State) -> | |||
{NewModel, Stats} = train_model(ModelName, TrainingData, State), | |||
NewModels = maps:put(ModelName, NewModel, State#state.models), | |||
NewStats = maps:put(ModelName, Stats, State#state.model_stats), | |||
{reply, {ok, Stats}, State#state{models = NewModels, model_stats = NewStats}}; | |||
handle_call({predict, ModelName, Features}, _From, State) -> | |||
case maps:get(ModelName, State#state.models, undefined) of | |||
undefined -> | |||
{reply, {error, model_not_found}, State}; | |||
Model -> | |||
Prediction = make_prediction(Model, Features), | |||
{reply, {ok, Prediction}, State} | |||
{ok, #state{ | |||
model = initialize_model(), | |||
last_update = os:timestamp(), | |||
hyperparameters = default_hyperparameters(), | |||
feature_config = default_feature_config() | |||
}}. | |||
handle_call({train, Data, Options}, _From, State) -> | |||
{Result, NewState} = do_train(Data, Options, State), | |||
{reply, Result, NewState}; | |||
handle_call({predict, Features, Options}, _From, State) -> | |||
{Prediction, NewState} = do_predict(Features, Options, State), | |||
{reply, Prediction, NewState}; | |||
handle_call(get_model_state, _From, State) -> | |||
{reply, get_current_model_state(State), State}; | |||
handle_call({save_model, Filename}, _From, State) -> | |||
Result = do_save_model(State, Filename), | |||
{reply, Result, State}; | |||
handle_call({load_model, Filename}, _From, State) -> | |||
case do_load_model(Filename) of | |||
{ok, NewState} -> {reply, ok, NewState}; | |||
Error -> {reply, Error, State} | |||
end; | |||
handle_call({get_stats, ModelName}, _From, State) -> | |||
Stats = maps:get(ModelName, State#state.model_stats, undefined), | |||
{reply, {ok, Stats}, State}. | |||
handle_call(_Request, _From, State) -> | |||
{reply, {error, unknown_call}, State}. | |||
handle_cast({update_model, NewModel, Options}, State) -> | |||
{noreply, do_update_model(State, NewModel, Options)}; | |||
handle_cast({add_training_sample, Sample}, State) -> | |||
{noreply, do_add_training_sample(State, Sample)}; | |||
handle_cast({update_model, ModelName, NewData, Reward}, State) -> | |||
Model = maps:get(ModelName, State#state.models), | |||
UpdatedModel = update_model_weights(Model, NewData, Reward), | |||
NewModels = maps:put(ModelName, UpdatedModel, State#state.models), | |||
{noreply, State#state{models = NewModels}}. | |||
handle_cast(_Msg, State) -> | |||
{noreply, State}. | |||
handle_info(update_metrics, State) -> | |||
{noreply, update_performance_metrics(State)}; | |||
handle_info(_Info, State) -> | |||
{noreply, State}. | |||
terminate(_Reason, State) -> | |||
save_final_state(State), | |||
ok. | |||
code_change(_OldVsn, State, _Extra) -> | |||
{ok, State}. | |||
%% 内部函数 | |||
initialize_models() -> | |||
Models = #{ | |||
play_strategy => init_play_strategy_model(), | |||
card_combination => init_card_combination_model(), | |||
opponent_prediction => init_opponent_prediction_model(), | |||
game_state_evaluation => init_game_state_model() | |||
}, | |||
Models. | |||
init_play_strategy_model() -> | |||
#model_data{ | |||
features = [ | |||
remaining_cards, | |||
opponent_cards, | |||
current_position, | |||
game_stage, | |||
last_play_type, | |||
has_control | |||
] | |||
%% 初始化模型 | |||
initialize_model() -> | |||
#model{ | |||
weights = initialize_weights(), | |||
biases = initialize_biases(), | |||
layers = default_layer_configuration(), | |||
activation_functions = default_activation_functions(), | |||
normalization_params = initialize_normalization_params() | |||
}. | |||
initialize_weights() -> | |||
#{ | |||
'input_layer' => random_matrix(64, 128), | |||
'hidden_layer_1' => random_matrix(128, 256), | |||
'hidden_layer_2' => random_matrix(256, 128), | |||
'output_layer' => random_matrix(128, 1) | |||
}. | |||
initialize_biases() -> | |||
#{ | |||
'input_layer' => zeros_vector(128), | |||
'hidden_layer_1' => zeros_vector(256), | |||
'hidden_layer_2' => zeros_vector(128), | |||
'output_layer' => zeros_vector(1) | |||
}. | |||
init_card_combination_model() -> | |||
#model_data{ | |||
features = [ | |||
card_count, | |||
card_types, | |||
sequence_length, | |||
combination_value | |||
] | |||
%% 默认配置 | |||
default_hyperparameters() -> | |||
#{ | |||
learning_rate => 0.001, | |||
batch_size => 32, | |||
epochs => 100, | |||
momentum => 0.9, | |||
dropout_rate => 0.5, | |||
l2_regularization => 0.01, | |||
early_stopping_patience => 5 | |||
}. | |||
init_opponent_prediction_model() -> | |||
#model_data{ | |||
features = [ | |||
played_cards, | |||
remaining_unknown, | |||
player_position, | |||
playing_pattern | |||
] | |||
default_feature_config() -> | |||
#{ | |||
card_value_weight => 1.0, | |||
card_type_weight => 0.8, | |||
sequence_weight => 1.2, | |||
combo_weight => 1.5, | |||
position_weight => 0.7, | |||
timing_weight => 0.9 | |||
}. | |||
init_game_state_model() -> | |||
#model_data{ | |||
features = [ | |||
cards_played, | |||
cards_remaining, | |||
player_positions, | |||
game_control | |||
] | |||
default_layer_configuration() -> | |||
[ | |||
{input, 64}, | |||
{dense, 128, relu}, | |||
{dropout, 0.5}, | |||
{dense, 256, relu}, | |||
{dropout, 0.5}, | |||
{dense, 128, relu}, | |||
{dense, 1, sigmoid} | |||
]. | |||
%% Named activation functions used by forward_layer/3.
%% Fix: relu used max(0, X), which returns integer 0 for negative input while
%% positive inputs stay floats; max(0.0, X) keeps the output type consistent.
%% NOTE(review): these funs expect scalar arguments, but forward_layer applies
%% them to whole matrices — confirm the intended element-wise application.
default_activation_functions() ->
    #{
        relu => fun(X) -> max(0.0, X) end,
        sigmoid => fun(X) -> 1 / (1 + math:exp(-X)) end,
        tanh => fun(X) -> math:tanh(X) end,
        softmax => fun softmax/1
    }.
train_model(ModelName, TrainingData, State) -> | |||
Model = maps:get(ModelName, State#state.models), | |||
{UpdatedModel, Stats} = case ModelName of | |||
play_strategy -> | |||
train_play_strategy(Model, TrainingData); | |||
card_combination -> | |||
train_card_combination(Model, TrainingData); | |||
opponent_prediction -> | |||
train_opponent_prediction(Model, TrainingData); | |||
game_state_evaluation -> | |||
train_game_state(Model, TrainingData) | |||
%% 训练相关函数 | |||
%% Train the model on Data: preprocess, split train/valid, fit, and fold the
%% new model and metrics back into State. Exceptions are converted to
%% {error, _} at this boundary (caller is a gen_server handle_call).
%% NOTE(review): split_train_valid/1 and several helpers are not visible in
%% this file as shown — confirm they exist before relying on this path.
do_train(Data, Options, State) ->
    try
        % preprocess raw samples into model-ready features
        ProcessedData = preprocess_data(Data, State),
        % split into training and validation sets
        {TrainData, ValidData} = split_train_valid(ProcessedData),
        % fit the model
        {NewModel, TrainMetrics} = train_model(TrainData, ValidData, Options, State),
        % fold the trained model and its metrics into the server state
        NewState = update_state_after_training(State, NewModel, TrainMetrics),
        {{ok, TrainMetrics}, NewState}
    catch
        Error:Reason ->
            {{error, {Error, Reason}}, State}
    end.
%% 预测相关函数 | |||
%% Run one prediction: preprocess the features, execute a forward pass, and
%% post-process the result. Exceptions become {error, _} at this boundary.
%% Fix: forward_pass/2 is defined as forward_pass(Model, Input) — the original
%% called it with the arguments swapped, folding the model over itself.
do_predict(Features, Options, State) ->
    try
        % feature normalization / transformation
        ProcessedFeatures = preprocess_features(Features, State),
        % network forward pass: (Model, Input) order
        Prediction = forward_pass(State#state.model, ProcessedFeatures),
        % shape the raw network output for the caller
        ProcessedPrediction = postprocess_prediction(Prediction, Options),
        {ProcessedPrediction, State}
    catch
        Error:Reason ->
            {{error, {Error, Reason}}, State}
    end.
%% 模型更新函数 | |||
%% Replace the current model (after validation), bump the version and refresh
%% the update timestamp.
%% Fix: the Options parameter was unused (compiler warning) — renamed
%% `_Options`; it is reserved for future update policies.
do_update_model(State, NewModel, _Options) ->
    ValidatedModel = validate_model(NewModel),
    State#state{
        model = ValidatedModel,
        model_version = State#state.model_version + 1,
        last_update = os:timestamp()
    }.
%% 添加训练样本 | |||
do_add_training_sample(State, Sample) -> | |||
ValidatedSample = validate_sample(Sample), | |||
NewTrainingData = [ValidatedSample | State#state.training_data], | |||
State#state{training_data = NewTrainingData}. | |||
%% 数据预处理函数 | |||
%% Training-data pipeline: extract -> normalize -> select -> augment.
%% Fix: the original read State#state.model.normalization_params, which is not
%% valid Erlang — chained record access needs the second record qualifier:
%% (State#state.model)#model.normalization_params.
preprocess_data(Data, State) ->
    % feature extraction
    Features = extract_features(Data),
    % feature normalization using the model's stored parameters
    NormParams = (State#state.model)#model.normalization_params,
    NormalizedFeatures = normalize_features(Features, NormParams),
    % feature selection per the configured weights
    SelectedFeatures = select_features(NormalizedFeatures, State#state.feature_config),
    % feature augmentation
    augment_features(SelectedFeatures).
%% 特征处理函数 | |||
%% Prediction-time feature pipeline: normalize then transform.
%% Fix: same invalid chained record access as preprocess_data/2 — now written
%% as (State#state.model)#model.normalization_params.
preprocess_features(Features, State) ->
    NormParams = (State#state.model)#model.normalization_params,
    NormalizedFeatures = normalize_features(Features, NormParams),
    transform_features(NormalizedFeatures).
%% 模型训练函数 | |||
train_model(TrainData, ValidData, Options, State) -> | |||
InitialModel = State#state.model, | |||
Epochs = maps:get(epochs, Options, 100), | |||
BatchSize = maps:get(batch_size, Options, 32), | |||
train_epochs(InitialModel, TrainData, ValidData, Epochs, BatchSize, Options). | |||
%% Run up to Epochs training epochs (re-batching each time) with early
%% stopping.
%% Fix: the original base clause returned {Model, []}, silently discarding the
%% final epoch's metrics when the epoch budget ran out; the most recent
%% metrics are now threaded through the loop.
train_epochs(Model, TrainData, ValidData, Epochs, BatchSize, Options) ->
    train_epochs_loop(Model, TrainData, ValidData, Epochs, BatchSize, Options, []).

%% Internal loop carrying the last epoch's metrics as the 7th argument.
train_epochs_loop(Model, _TrainData, _ValidData, 0, _BatchSize, _Options, LastMetrics) ->
    {Model, LastMetrics};
train_epochs_loop(Model, TrainData, ValidData, EpochsLeft, BatchSize, Options, _Prev) ->
    % build mini-batches for this epoch
    Batches = create_batches(TrainData, BatchSize),
    % train one epoch (all batches + validation)
    {UpdatedModel, EpochMetrics} = train_epoch(Model, Batches, ValidData, Options),
    % stop early when the metrics say so
    case should_early_stop(EpochMetrics, Options) of
        true ->
            {UpdatedModel, EpochMetrics};
        false ->
            train_epochs_loop(UpdatedModel, TrainData, ValidData,
                              EpochsLeft - 1, BatchSize, Options, EpochMetrics)
    end.
train_epoch(Model, Batches, ValidData, Options) -> | |||
% 训练所有批次 | |||
{TrainedModel, BatchMetrics} = train_batches(Model, Batches, Options), | |||
% 在验证集上评估 | |||
ValidationMetrics = evaluate_model(TrainedModel, ValidData), | |||
% 合并指标 | |||
EpochMetrics = merge_metrics(BatchMetrics, ValidationMetrics), | |||
{TrainedModel, EpochMetrics}. | |||
train_batches(Model, Batches, Options) -> | |||
lists:foldl( | |||
fun(Batch, {CurrentModel, Metrics}) -> | |||
{UpdatedModel, BatchMetric} = train_batch(CurrentModel, Batch, Options), | |||
{UpdatedModel, [BatchMetric|Metrics]} | |||
end, | |||
{Model, []}, | |||
Batches | |||
). | |||
train_batch(Model, Batch, Options) -> | |||
% 前向传播 | |||
{Predictions, CacheData} = forward_pass_with_cache(Model, Batch), | |||
% 计算损失 | |||
{Loss, LossGrad} = calculate_loss(Predictions, Batch, Options), | |||
% 反向传播 | |||
Gradients = backward_pass(LossGrad, CacheData, Model), | |||
% 更新模型参数 | |||
UpdatedModel = update_model_parameters(Model, Gradients, Options), | |||
% 返回更新后的模型和训练指标 | |||
{UpdatedModel, #{loss => Loss}}. | |||
%% 模型评估函数 | |||
evaluate_model(Model, Data) -> | |||
% 前向传播 | |||
Predictions = forward_pass(Model, Data), | |||
% 计算各种评估指标 | |||
#{ | |||
accuracy => calculate_accuracy(Predictions, Data), | |||
precision => calculate_precision(Predictions, Data), | |||
recall => calculate_recall(Predictions, Data), | |||
f1_score => calculate_f1_score(Predictions, Data) | |||
}. | |||
%% 工具函数 | |||
%% Rows x Cols matrix of standard-normal draws scaled by 1/sqrt(Rows)
%% (Xavier-style fan-in scaling).
random_matrix(Rows, Cols) ->
    Denom = math:sqrt(Rows),
    MakeRow = fun(_) -> [rand:normal() / Denom || _ <- lists:seq(1, Cols)] end,
    lists:map(MakeRow, lists:seq(1, Rows)).
%% Vector of Size zeros (floats).
zeros_vector(Size) ->
    lists:duplicate(Size, 0.0).
%% Plain softmax: exponentiate each element and normalize by the total.
%% (No max-subtraction, so very large inputs can overflow math:exp/1 —
%% matching the original behavior.)
softmax(Xs) ->
    Exps = lists:map(fun math:exp/1, Xs),
    Total = lists:sum(Exps),
    lists:map(fun(E) -> E / Total end, Exps).
%% Split Data into consecutive batches of at most BatchSize elements,
%% preserving order; the final batch may be short.
%% NOTE(review): this span was corrupted by a bad merge — fragments of the
%% removed make_prediction/2, update_model_weights/3 and calculate_prediction/2
%% were spliced into the middle of create_batches/3, breaking the syntax.
%% This reconstruction restores the batching function those fragments
%% interrupted; the removed old-version functions belonged to the deleted
%% map-of-models API and have no callers in the new code.
create_batches(Data, BatchSize) ->
    create_batches(Data, BatchSize, []).

create_batches([], _BatchSize, Acc) ->
    lists:reverse(Acc);
create_batches(Data, BatchSize, Acc) when length(Data) > BatchSize ->
    {Batch, Rest} = lists:split(BatchSize, Data),
    create_batches(Rest, BatchSize, [Batch | Acc]);
create_batches(Data, _BatchSize, Acc) ->
    %% Fewer than BatchSize elements left: they form the final batch.
    lists:reverse([Data | Acc]).
%% 保存和加载模型 | |||
do_save_model(State, Filename) -> | |||
ModelData = #{ | |||
model => State#state.model, | |||
version => State#state.model_version, | |||
hyperparameters => State#state.hyperparameters, | |||
feature_config => State#state.feature_config, | |||
timestamp => os:timestamp() | |||
}, | |||
file:write_file(Filename, term_to_binary(ModelData)). | |||
%% Load a previously saved model state from Filename.
%% NOTE(review): binary_to_term/1 is used without the `safe` option — the file
%% is assumed trusted (written by do_save_model/2). `safe` would also reject
%% the fun values stored in activation_functions, so do not add it blindly.
do_load_model(Filename) ->
    case file:read_file(Filename) of
        {ok, Binary} ->
            try
                ModelData = binary_to_term(Binary),
                {ok, create_state_from_model_data(ModelData)}
            catch
                % any decode/shape failure maps to one error reason
                _:_ -> {error, invalid_model_file}
            end;
        Error ->
            Error
    end.
create_state_from_model_data(ModelData) -> | |||
#state{ | |||
model = maps:get(model, ModelData), | |||
model_version = maps:get(version, ModelData), | |||
hyperparameters = maps:get(hyperparameters, ModelData), | |||
feature_config = maps:get(feature_config, ModelData), | |||
last_update = maps:get(timestamp, ModelData) | |||
}. | |||
%% 性能指标更新 | |||
update_performance_metrics(State) -> | |||
NewMetrics = calculate_current_metrics(State), | |||
State#state{ | |||
performance_metrics = [NewMetrics | State#state.performance_metrics] | |||
}. | |||
calculate_current_metrics(State) -> | |||
Model = State#state.model, | |||
#{ | |||
loss => Model#model.last_train_error, | |||
timestamp => os:timestamp() | |||
}. | |||
%% 状态更新 | |||
update_state_after_training(State, NewModel, Metrics) -> | |||
State#state{ | |||
model = NewModel, | |||
model_version = State#state.model_version + 1, | |||
last_update = os:timestamp(), | |||
performance_metrics = [Metrics | State#state.performance_metrics] | |||
}. | |||
%% 验证模型(续) | |||
validate_model(Model) -> | |||
% 验证权重和偏置 | |||
ValidatedWeights = validate_weights(Model#model.weights), | |||
ValidatedBiases = validate_biases(Model#model.biases), | |||
% 验证网络层配置 | |||
ValidatedLayers = validate_layers(Model#model.layers), | |||
% 验证激活函数 | |||
ValidatedActivations = validate_activation_functions(Model#model.activation_functions), | |||
Model#model{ | |||
weights = ValidatedWeights, | |||
biases = ValidatedBiases, | |||
layers = ValidatedLayers, | |||
activation_functions = ValidatedActivations | |||
}. | |||
%% Validate every layer's weight matrix; the layer name is not needed.
%% Fix: the fun parameter `Layer` was unused in both funs (compiler warning) —
%% renamed `_Layer`.
validate_weights(Weights) ->
    maps:map(fun(_Layer, W) -> validate_weight_matrix(W) end, Weights).

%% Validate every layer's bias vector.
validate_biases(Biases) ->
    maps:map(fun(_Layer, B) -> validate_bias_vector(B) end, Biases).
validate_layers(Layers) -> | |||
lists:map(fun validate_layer/1, Layers). | |||
validate_activation_functions(ActivationFns) -> | |||
maps:filter(fun(Name, Fn) -> | |||
is_valid_activation_function(Name, Fn) | |||
end, ActivationFns). | |||
%% 前向传播相关函数 | |||
%% Forward pass without keeping the per-layer cache. Argument order is
%% (Model, Input) — callers must respect this.
forward_pass(Model, Input) ->
    {Output, _Cache} = forward_pass_with_cache(Model, Input),
    Output.
%% Forward pass that also records each layer's intermediate values (keyed by
%% layer name) for use by backpropagation.
%% NOTE(review): the original foldl carried a stray second fun head — three
%% lines from the removed calculate_prediction/2 fold were spliced in by a bad
%% merge, breaking the syntax. This restores the intended single layer fold.
forward_pass_with_cache(Model, Input) ->
    InitialCache = #{input => Input},
    lists:foldl(
        fun(Layer, {CurrentInput, Cache}) ->
            {Output, LayerCache} = forward_layer(Layer, CurrentInput, Model),
            {Output, Cache#{get_layer_name(Layer) => LayerCache}}
        end,
        {Input, InitialCache},
        Model#model.layers
    ).
%% Apply one layer of the network to Input, returning {Output, LayerCache}.
%% NOTE(review): several apparent defects, left untouched because the fix
%% touches the whole weight-storage scheme:
%%   - weights/biases are looked up with the literal key `dense`, but
%%     initialize_weights/biases store keys input_layer/hidden_layer_1/... —
%%     this lookup would raise badkey; confirm the intended keying.
%%   - `matrix_multiply(...) + Bias` applies `+` to lists (badarith); an
%%     element-wise matrix/vector add is presumably intended.
%%   - the Size element of the dense tuple is unused here.
forward_layer({dense, Size, Activation}, Input, Model) ->
    Weights = maps:get(dense, Model#model.weights),
    Bias = maps:get(dense, Model#model.biases),
    % linear transform: Z = Input * W + b
    Z = matrix_multiply(Input, Weights) + Bias,
    % apply the configured activation function
    ActivationFn = maps:get(Activation, Model#model.activation_functions),
    Output = ActivationFn(Z),
    {Output, #{pre_activation => Z, output => Output}};
forward_layer({dropout, Rate}, Input, Model) ->
    % dropout only applies while training; at inference it is the identity.
    % NOTE(review): get_training_mode/1 is not defined in this file as shown.
    case get_training_mode(Model) of
        true ->
            Mask = generate_dropout_mask(Input, Rate),
            Output = element_wise_multiply(Input, Mask),
            {Output, #{mask => Mask}};
        false ->
            {Input, #{}}
    end.
%% 反向传播相关函数 | |||
backward_pass(LossGrad, Cache, Model) -> | |||
{_, Gradients} = lists:foldr( | |||
fun(Layer, {CurrentGrad, LayerGrads}) -> | |||
LayerCache = maps:get(get_layer_name(Layer), Cache), | |||
{NextGrad, LayerGrad} = backward_layer(Layer, CurrentGrad, LayerCache, Model), | |||
{NextGrad, [LayerGrad | LayerGrads]} | |||
end, | |||
0, | |||
Features | |||
). | |||
{LossGrad, []}, | |||
Model#model.layers | |||
), | |||
consolidate_gradients(Gradients). | |||
%% Gradient of one layer: returns {GradWrtInput, #{weights/bias gradients}}.
%% NOTE(review): apparent defects left untouched (coupled to forward_layer's
%% weight-storage scheme):
%%   - maps:get(input, Cache): the per-layer cache built by forward_layer only
%%     stores pre_activation/output (dense) or mask (dropout) — badkey here.
%%   - maps:get(dense, ...weights) has the same literal-key problem as the
%%     forward pass.
%%   - Size and Rate are unused (compiler warnings).
backward_layer({dense, Size, Activation}, Grad, Cache, Model) ->
    % derivative of the activation function
    ActivationGrad = get_activation_gradient(Activation),
    % gradient through the activation: dZ = dOut .* f'(Z)
    PreAct = maps:get(pre_activation, Cache),
    DZ = element_wise_multiply(Grad, ActivationGrad(PreAct)),
    % gradients for the layer parameters
    Input = maps:get(input, Cache),
    WeightGrad = matrix_multiply(transpose(Input), DZ),
    BiasGrad = sum_columns(DZ),
    % gradient propagated to the previous layer
    Weights = maps:get(dense, Model#model.weights),
    InputGrad = matrix_multiply(DZ, transpose(Weights)),
    {InputGrad, #{weights => WeightGrad, bias => BiasGrad}};
backward_layer({dropout, Rate}, Grad, Cache, _Model) ->
    % dropout backward: re-apply the same mask used in the forward pass
    Mask = maps:get(mask, Cache),
    {element_wise_multiply(Grad, Mask), #{}}.
%% 损失函数相关 | |||
calculate_loss(Predictions, Targets, Options) -> | |||
LossType = maps:get(loss_type, Options, cross_entropy), | |||
calculate_loss_by_type(LossType, Predictions, Targets). | |||
calculate_loss_by_type(cross_entropy, Predictions, Targets) -> | |||
Loss = cross_entropy_loss(Predictions, Targets), | |||
Gradient = cross_entropy_gradient(Predictions, Targets), | |||
{Loss, Gradient}; | |||
calculate_loss_by_type(mse, Predictions, Targets) -> | |||
Loss = mean_squared_error(Predictions, Targets), | |||
Gradient = mse_gradient(Predictions, Targets), | |||
{Loss, Gradient}. | |||
%% 优化器相关函数 | |||
%% Apply one optimizer step using the Options-selected optimizer
%% (default adam) and learning rate (default 0.001).
update_model_parameters(Model, Gradients, Options) ->
    Optimizer = maps:get(optimizer, Options, adam),
    LearningRate = maps:get(learning_rate, Options, 0.001),
    update_parameters_with_optimizer(Model, Gradients, Optimizer, LearningRate).
%% Apply one optimizer step to the model's weights.
%% NOTE(review): the original adam branch could not work: it stored momentum
%% via Model#model{momentum = ...} — but #model has no momentum field (compile
%% error) — and read it with maps:get(momentum, Model, #{}) where Model is a
%% record, not a map (badmap at runtime). Until a momentum slot is added to
%% #model, the adam branch degrades to a plain gradient step so the module
%% compiles and runs; restore true Adam once the state has somewhere to live.
update_parameters_with_optimizer(Model, Gradients, adam, LearningRate) ->
    NewWeights = update_sgd_parameters(
        Model#model.weights,
        maps:get(weights, Gradients),
        LearningRate
    ),
    Model#model{weights = NewWeights};
update_parameters_with_optimizer(Model, Gradients, sgd, LearningRate) ->
    % plain stochastic gradient descent: W := W - lr * dW
    NewWeights = update_sgd_parameters(
        Model#model.weights,
        maps:get(weights, Gradients),
        LearningRate
    ),
    Model#model{weights = NewWeights}.
%% 特征工程相关函数 | |||
extract_features(Data) -> | |||
lists:map(fun extract_sample_features/1, Data). | |||
extract_sample_features(Sample) -> | |||
BasicFeatures = extract_basic_features(Sample), | |||
AdvancedFeatures = extract_advanced_features(Sample), | |||
combine_features(BasicFeatures, AdvancedFeatures). | |||
extract_basic_features(Sample) -> | |||
#{ | |||
card_values => extract_card_values(Sample), | |||
card_types => extract_card_types(Sample), | |||
card_counts => extract_card_counts(Sample) | |||
}. | |||
extract_advanced_features(Sample) -> | |||
#{ | |||
combinations => find_card_combinations(Sample), | |||
sequences => find_card_sequences(Sample), | |||
special_patterns => find_special_patterns(Sample) | |||
}. | |||
%% 矩阵操作辅助函数 | |||
matrix_multiply(A, B) -> | |||
% 矩阵乘法实现 | |||
case {matrix_dimensions(A), matrix_dimensions(B)} of | |||
{{RowsA, ColsA}, {RowsB, ColsB}} when ColsA =:= RowsB -> | |||
do_matrix_multiply(A, B, RowsA, ColsB); | |||
_ -> | |||
error(matrix_dimension_mismatch) | |||
end. | |||
do_matrix_multiply(A, B, RowsA, ColsB) -> | |||
[[dot_product(get_row(A, I), get_col(B, J)) || J <- lists:seq(1, ColsB)] | |||
|| I <- lists:seq(1, RowsA)]. | |||
%% Inner product of two equal-length vectors.
dot_product(Vec1, Vec2) ->
    lists:foldl(fun({X, Y}, Acc) -> Acc + X * Y end, 0, lists:zip(Vec1, Vec2)).
%% Transpose a list-of-lists matrix; an empty matrix (or one with empty rows)
%% transposes to [].
transpose([]) ->
    [];
transpose([[] | _]) ->
    [];
transpose(Matrix) ->
    ColCount = length(hd(Matrix)),
    [[lists:nth(J, Row) || Row <- Matrix] || J <- lists:seq(1, ColCount)].
%% 工具函数 | |||
get_layer_name({Type, Size, _}) -> | |||
atom_to_list(Type) ++ "_" ++ integer_to_list(Size); | |||
get_layer_name({Type, Rate}) -> | |||
atom_to_list(Type) ++ "_" ++ float_to_list(Rate). | |||
%% Build a 0.0/1.0 mask with Input's dimensions; each cell is dropped (0.0)
%% with probability Rate.
%% Fix: the original bound the whole {Rows, Cols} tuple from
%% matrix_dimensions/1 to one variable and passed it to lists:seq/2 (badarg),
%% and would have produced a square mask regardless of the input's true shape.
generate_dropout_mask(Input, Rate) ->
    Rows = length(Input),
    Cols = length(hd(Input)),
    [[case rand:uniform() < Rate of true -> 0.0; false -> 1.0 end
      || _ <- lists:seq(1, Cols)]
     || _ <- lists:seq(1, Rows)].
%% Hadamard (cell-wise) product of two same-shape list matrices.
element_wise_multiply(A, B) ->
    lists:zipwith(
        fun(RowA, RowB) -> lists:zipwith(fun(X, Y) -> X * Y end, RowA, RowB) end,
        A,
        B).
%% Column-wise totals: one sum per column, accumulated from 0.0 so the result
%% is always a float vector.
sum_columns([FirstRow | _] = Matrix) ->
    Zero = lists:duplicate(length(FirstRow), 0.0),
    lists:foldl(
        fun(Row, Acc) -> lists:zipwith(fun(X, S) -> S + X end, Row, Acc) end,
        Zero,
        Matrix).
%% {RowCount, ColCount} of a list-of-lists matrix; degenerate inputs are {0, 0}.
matrix_dimensions([]) -> {0, 0};
matrix_dimensions([[] | _]) -> {0, 0};
matrix_dimensions([FirstRow | _] = Matrix) ->
    {length(Matrix), length(FirstRow)}.

%% 1-based row extraction.
get_row(Matrix, RowIdx) ->
    lists:nth(RowIdx, Matrix).

%% 1-based column extraction.
get_col(Matrix, ColIdx) ->
    [lists:nth(ColIdx, Row) || Row <- Matrix].
%% 初始化和保存最终状态 | |||
save_final_state(State) -> | |||
Filename = "ml_model_" ++ format_timestamp() ++ ".state", | |||
do_save_model(State, Filename). | |||
%% Current UTC time as a fixed 15-character "YYYYMMDD_HHMMSS" string,
%% suitable for filenames.
format_timestamp() ->
    {{Y, Mo, D}, {H, Mi, S}} = calendar:universal_time(),
    Fmt = "~4..0w~2..0w~2..0w_~2..0w~2..0w~2..0w",
    lists:flatten(io_lib:format(Fmt, [Y, Mo, D, H, Mi, S])).
@ -1,62 +0,0 @@ | |||
-module(opponent_modeling). | |||
-export([create_model/0, update_model/2, analyze_opponent/2, predict_play/2]). | |||
-record(opponent_model, { | |||
play_patterns = #{}, % 出牌模式统计 | |||
card_preferences = #{}, % 牌型偏好 | |||
risk_profile = 0.5, % 风险偏好 | |||
skill_rating = 500, % 技能评分 | |||
play_history = [] % 历史出牌记录 | |||
}). | |||
create_model() -> | |||
#opponent_model{}. | |||
update_model(Model, GamePlay) -> | |||
% 更新出牌模式统计 | |||
NewPatterns = update_play_patterns(Model#opponent_model.play_patterns, GamePlay), | |||
% 更新牌型偏好 | |||
NewPreferences = update_card_preferences(Model#opponent_model.card_preferences, GamePlay), | |||
% 更新风险偏好 | |||
NewRiskProfile = calculate_risk_profile(Model#opponent_model.risk_profile, GamePlay), | |||
% 更新技能评分 | |||
NewSkillRating = update_skill_rating(Model#opponent_model.skill_rating, GamePlay), | |||
% 更新历史记录 | |||
NewHistory = [GamePlay | Model#opponent_model.play_history], | |||
Model#opponent_model{ | |||
play_patterns = NewPatterns, | |||
card_preferences = NewPreferences, | |||
risk_profile = NewRiskProfile, | |||
skill_rating = NewSkillRating, | |||
play_history = lists:sublist(NewHistory, 100) % 保留最近100次出牌记录 | |||
}. | |||
%% Summarize an opponent model as a map of style, strength, predictability and
%% identified weaknesses.
%% Fix: the GameState parameter was unused (compiler warning) — renamed
%% `_GameState`; it is kept for interface compatibility with callers.
analyze_opponent(Model, _GameState) ->
    #{
        style => determine_play_style(Model),
        strength => calculate_opponent_strength(Model),
        predictability => calculate_predictability(Model),
        weakness => identify_weaknesses(Model)
    }.
predict_play(Model, GameState) -> | |||
% 基于历史模式预测 | |||
HistoryBasedPrediction = predict_from_history(Model, GameState), | |||
% 基于牌型偏好预测 | |||
PreferenceBasedPrediction = predict_from_preferences(Model, GameState), | |||
% 基于风险偏好预测 | |||
RiskBasedPrediction = predict_from_risk_profile(Model, GameState), | |||
% 综合预测结果 | |||
combine_predictions([ | |||
{HistoryBasedPrediction, 0.4}, | |||
{PreferenceBasedPrediction, 0.3}, | |||
{RiskBasedPrediction, 0.3} | |||
]). |
@ -1,55 +0,0 @@ | |||
-module(parallel_compute). | |||
-behaviour(gen_server). | |||
-export([start_link/0, init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). | |||
-export([parallel_predict/2, batch_process/2]). | |||
-record(state, { | |||
worker_pool = [], % 工作进程池 | |||
job_queue = [], % 任务队列 | |||
results = #{}, % 结果集 | |||
pool_size = 4 % 默认工作进程数 | |||
}). | |||
%% API | |||
start_link() -> | |||
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). | |||
parallel_predict(Inputs, Model) -> | |||
gen_server:call(?MODULE, {parallel_predict, Inputs, Model}). | |||
batch_process(BatchData, ProcessFun) -> | |||
gen_server:call(?MODULE, {batch_process, BatchData, ProcessFun}). | |||
%% 内部函数 | |||
initialize_worker_pool(PoolSize) -> | |||
[spawn_worker() || _ <- lists:seq(1, PoolSize)]. | |||
spawn_worker() -> | |||
spawn_link(fun() -> worker_loop() end). | |||
worker_loop() -> | |||
receive | |||
{process, Data, From} -> | |||
Result = process_data(Data), | |||
From ! {result, self(), Result}, | |||
worker_loop(); | |||
stop -> | |||
ok | |||
end. | |||
process_data({predict, Input, Model}) -> | |||
deep_learning:predict(Model, Input); | |||
process_data({custom, Fun, Data}) -> | |||
Fun(Data). | |||
%% Round-robin Jobs across the worker pool, then gather one reply per job.
%% Fix: the original dispatched every job but never received the
%% {result, WorkerPid, Result} messages the workers send back, so it always
%% returned its (empty) accumulator. Results are now collected into a map of
%% WorkerPid => [Result], per-worker in completion order.
distribute_work(Workers, Jobs) ->
    Dispatched = dispatch_jobs(Workers, Jobs, 0),
    await_results(Dispatched, #{}).

%% Send each job to the head worker and rotate that worker to the back of the
%% pool; returns the number of jobs dispatched.
dispatch_jobs(_Workers, [], Count) ->
    Count;
dispatch_jobs([Worker | Rest], [Job | Jobs], Count) ->
    Worker ! {process, Job, self()},
    dispatch_jobs(Rest ++ [Worker], Jobs, Count + 1).

%% Collect one result message per dispatched job; give up (returning whatever
%% has arrived) after a 5-second quiet period so a dead worker cannot hang us.
await_results(0, Results) ->
    Results;
await_results(Pending, Results) ->
    receive
        {result, WorkerPid, Result} ->
            Bucket = maps:get(WorkerPid, Results, []),
            await_results(Pending - 1, Results#{WorkerPid => Bucket ++ [Result]})
    after 5000 ->
        Results
    end.
@ -1,37 +0,0 @@ | |||
-module(test_suite). | |||
-export([run_full_test/0, validate_ai_performance/1]). | |||
run_full_test() -> | |||
% 运行基础功能测试 | |||
BasicTests = run_basic_tests(), | |||
% 运行AI系统测试 | |||
AITests = run_ai_tests(), | |||
% 运行性能测试 | |||
PerformanceTests = run_performance_tests(), | |||
% 生成测试报告 | |||
generate_test_report([ | |||
{basic_tests, BasicTests}, | |||
{ai_tests, AITests}, | |||
{performance_tests, PerformanceTests} | |||
]). | |||
validate_ai_performance(AISystem) -> | |||
% 运行测试游戏 | |||
TestGames = run_test_games(AISystem, 1000), | |||
% 分析胜率 | |||
WinRate = analyze_win_rate(TestGames), | |||
% 分析决策质量 | |||
DecisionQuality = analyze_decision_quality(TestGames), | |||
% 生成性能报告 | |||
#{ | |||
win_rate => WinRate, | |||
decision_quality => DecisionQuality, | |||
average_response_time => calculate_avg_response_time(TestGames), | |||
memory_usage => measure_memory_usage(AISystem) | |||
}. |
@ -1,25 +0,0 @@ | |||
-module(training_system). | |||
-export([start_training/0, process_game_data/1, update_models/1]). | |||
%% 开始训练过程 | |||
start_training() -> | |||
TrainingData = load_training_data(), | |||
Models = initialize_models(), | |||
train_models(Models, TrainingData, [ | |||
{epochs, 1000}, | |||
{batch_size, 32}, | |||
{learning_rate, 0.001} | |||
]). | |||
%% 处理游戏数据 | |||
process_game_data(GameRecord) -> | |||
Features = extract_features(GameRecord), | |||
Labels = extract_labels(GameRecord), | |||
update_training_dataset(Features, Labels). | |||
%% 更新模型 | |||
update_models(NewData) -> | |||
CurrentModels = get_current_models(), | |||
UpdatedModels = retrain_models(CurrentModels, NewData), | |||
validate_models(UpdatedModels), | |||
deploy_models(UpdatedModels). |