from algorithm.problem import Problem
from algorithm.mcts import MCTS
from algorithm.constraint import TupleInSetConstraint
from common.metric import Metrics, MeanMetric
from algorithm.solver import BalSolver
from common.parser import parser_csp
from algorithm.env import Env
from net.network import PolicyValueNet, GNNNet
from net.replaybuffer import ReplayBuffer
from torch_geometric.data import Data
from replaybuffer import ReplayBuffer
import torch
import ipdb
import numpy as np
from algorithm.mcts import MCTS
from tqdm._tqdm import trange

class Trainer:
    """Train ``PolicyValueNet`` on search data collected by MCTS over CSP files.

    Problems are read from ``data/<n>.txt``; search results are stored in a
    replay buffer, the network is updated from buffer samples, and the
    current/best models are checkpointed under ``./model/``.
    """

    def __init__(self, device='cuda:1'):
        """Set hyper-parameters and build the network, replay buffer and agent.

        Args:
            device: torch device string passed to ``PolicyValueNet``. Defaults
                to the previously hard-coded ``'cuda:1'``.
        """
        self.model = None
        self.device = device
        self.search_num = 10000
        # NOTE(review): ``get_search`` reads ``self.max_iterations`` but it was
        # never assigned, so every call raised AttributeError. It is tied to
        # ``search_num`` here as a best guess -- confirm the value MCTS expects.
        self.max_iterations = self.search_num
        self.simulate_num = 10
        self.epochs = 10
        self.batch_size = 128
        self.buffer_size = 10000

        self.policy_value_net = PolicyValueNet(device=self.device)
        # NOTE(review): ReplayBuffer is imported twice at the top of the file;
        # the later ``from replaybuffer import ReplayBuffer`` is the one in effect.
        self.buffer = ReplayBuffer(self.buffer_size)
        # NOTE(review): MCTS is built here as MCTS(net, c_puct=..., simulate_num=...)
        # but in get_search as MCTS(env, max_iterations, net); the two call
        # shapes disagree -- confirm against the real MCTS signature.
        self.agent = MCTS(self.policy_value_net, c_puct=1, simulate_num=self.simulate_num)

        self.learn_rate = 2e-3
        self.best_win_ratio = 0.0
        # Number of collected problems between evaluation/checkpoint rounds in run().
        self.output_frequency = 20
        self.metrics = Metrics()  # Initialize Metrics
        self.metrics.create_metric('win_rate', MeanMetric())  # Metric for the win rate
        self.metrics.create_metric('states_len', MeanMetric())  # Metric for len(states)

    def get_search(self, file_path):
        """Build an MCTS searcher for the CSP instance stored at ``file_path``.

        The file is parsed with ``parser_csp``; each ``(var1, var2)`` entry of the
        parsed constraints becomes a ``TupleInSetConstraint`` on that pair.

        Returns:
            An ``MCTS`` object over an ``Env`` built from the problem's
            domains and constraints.
        """
        problem = Problem(BalSolver())
        variables, constraints, domains = parser_csp(file_path)
        problem.addVariables(variables, domains)
        for (var1, var2), invalid_tuples in constraints.items():
            problem.addConstraint(TupleInSetConstraint(invalid_tuples), [var1, var2])
        domains, constraints, vconstraints = problem._getArgs()
        env = Env(domains, constraints, vconstraints)
        return MCTS(env, self.max_iterations, self.policy_value_net)

    def collect_search_data(self, file_num):
        """Run one search on ``data/<file_num>.txt`` and store its transitions."""
        mcts = self.get_search(f'data/{file_num}.txt')
        data = mcts.run(self.agent)
        # Assumes run(agent) returns a mapping with these keys -- confirm.
        self.buffer.store(data['state'], data['action'], data['reward'], data['next_state'], data['done'])

    def policy_update(self):
        """Take ``self.epochs`` training steps on one batch sampled from the buffer.

        Returns:
            The loss from the last ``train_step`` call, or ``None`` when
            ``self.epochs`` is not positive (previously an UnboundLocalError).
        """
        samples = self.buffer.sample(self.batch_size)
        # sample() is treated as an iterable of batches; only the first is used.
        batch = next(iter(samples))

        loss = None
        for _ in range(self.epochs):
            loss = self.policy_value_net.train_step(batch, self.learn_rate)
        return loss

    def policy_evaluate(self, num_problems=100):
        """Estimate the solve rate of the current policy.

        Runs ``self.search_num`` fresh searches on each of ``data/0.txt`` ...
        ``data/<num_problems - 1>.txt``.

        Args:
            num_problems: number of problem files to evaluate on (default 100,
                the previously hard-coded value).

        Returns:
            Fraction of all searches counted as wins, in ``[0, 1]``; ``0.0``
            when no search was run.
        """
        win = 0
        total = 0
        for i in range(num_problems):
            for _ in range(self.search_num):
                mcts = self.get_search(f'data/{i}.txt')
                # NOTE(review): assumes the first value returned by run() is
                # truthy when the search solved the instance -- confirm against
                # MCTS.run. (collect_search_data calls run(self.agent) and
                # indexes a dict, so the two call sites disagree.)
                end, _, _ = mcts.run()
                if end:
                    win += 1
                total += 1
        # Divide by the number of searches actually run, not by search_num.
        return win / total if total else 0.0

    def run(self):
        """Main training loop.

        For each of ``self.epochs`` passes over ``data/0.txt`` ... ``data/999.txt``:
        collect search data, and update the network once the buffer holds more
        than ``batch_size`` items. Every ``output_frequency`` problems, evaluate
        the policy and save ``./model/cur.model``. When the win ratio improves,
        also save ``./model/best.model``.
        """
        # Guard against a zero/negative frequency turning into a modulo error.
        frequency = max(1, self.output_frequency)
        for _ in trange(self.epochs):
            for num in trange(1000):
                self.collect_search_data(num)
                if len(self.buffer) > self.batch_size:
                    loss = self.policy_update()
                    print('loss: ', loss)
                # Evaluation runs num_problems * search_num full searches, so it is
                # gated on output_frequency (previously set but unused) instead of
                # running after every collected problem.
                if (num + 1) % frequency != 0:
                    continue
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./model/cur.model')
                if win_ratio > self.best_win_ratio:
                    print('New best policy!!!!!!!!')
                    self.best_win_ratio = win_ratio
                    self.policy_value_net.save_model('./model/best.model')

    def print_metrics(self):
        """Print the averaged win rate and states length, then reset the metrics.

        NOTE(review): nothing in this class updates these metrics yet.
        """
        avg_win_rate = self.metrics.metrics['win_rate'].val()
        avg_states_len = self.metrics.metrics['states_len'].val()
        print(f'  Average Win Rate: {avg_win_rate:.4f}')
        print(f'  Average States Length: {avg_states_len:.4f}')
        self.metrics.reset()  # Reset Metrics
        
if __name__ == '__main__':
    # Script entry point: build a trainer with its default settings and train it.
    Trainer().run()
