import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

from typing import Optional, Any, Dict, List
import os
import torch
import math
import time
import json
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import numpy as np
from tqdm import tqdm
from dataloader.code.dataset import BlendableDataset, RLFullDataset
from dataloader.code.problem_loader import DDP_ProblemLoader
from dataloader.code.data_samplers import build_training_data_loader
from dataloader.code.tokenizer import ContinuousScalarTokenizer
import pprint
from utils.utils import set_seed, create_file_if_not_exist
from evaluate_test.evaluate_utils import evalute_one_episode, evalute_batch_episode
import matplotlib.pyplot as plt

class Evaluater:
    def __init__(self, args, eval_args, gato, envs, logger, datasets_train, datasets_test, dataset_weights, envs_problems):
        """Set up the distributed evaluation state for a model.

        Stores the configuration objects, reads the torchrun rank variables,
        builds the data/problem loaders, wraps ``gato`` in
        ``DistributedDataParallel`` and creates the scalar tokenizer.

        Args:
            args: Configuration namespace. The tokenizer settings
                (``tokenizer_ver``, ``num_continuous_bin``, ``discretize_mu``,
                ``discretize_M``) are read from it here.
            eval_args: Evaluation-specific configuration namespace.
            gato: Model to evaluate; it must expose
                ``transformer.configure_optimizers()``.
            envs: Environments, indexed in the iteration order of
                ``envs_problems``.
            logger: Per-environment logger mapping, or ``None``.
            datasets_train: Forwarded to ``_prepare_dataloader``.
            datasets_test: Forwarded to ``_prepare_dataloader``.
            dataset_weights: Forwarded to ``_prepare_dataloader``.
            envs_problems: Mapping from environment name to problem dataset.

        Raises:
            KeyError: If ``LOCAL_RANK``, ``RANK`` or ``WORLD_SIZE`` is not set
                in the process environment.
        """
        self.args, self.eval_args = args, eval_args
        self.envs, self.logger = envs, logger
        self.envs_problems = envs_problems

        # Per-environment problem statistics; populated inside
        # _prepare_dataloader, so they must exist before it is called.
        self.problem_num = {}
        self.best_obj_ave, self.random_obj_ave = {}, {}
        self.best_obj_std, self.random_obj_std = {}, {}

        # Rank information exported by torchrun.
        environ = os.environ
        self.local_rank = int(environ["LOCAL_RANK"])
        self.global_rank = int(environ["RANK"])
        self.world_size = int(environ["WORLD_SIZE"])

        # Data loaders, problem loaders and the datasets behind them.
        prepared = self._prepare_dataloader(datasets_train, datasets_test, dataset_weights, envs_problems)
        (
            self.dataloader_train,
            self.dataloader_val,
            self.dataloader_test,
            self.problemloader_dict,
            self.dataset_train,
            self.dataset_val,
            self.dataset_test,
        ) = prepared

        # Move the model to this process's device and wrap it with DDP, which
        # synchronises the model across all processes.
        local_model = gato.to(self.local_rank).train()
        self.gato = DDP(local_model, device_ids=[self.local_rank])
        # Handle to the unwrapped model (falls back to the wrapper itself).
        self.raw_gato = getattr(self.gato, "module", self.gato)
        self.optimizer = self.raw_gato.transformer.configure_optimizers()

        tokenizer_settings = (
            args.tokenizer_ver,
            args.num_continuous_bin,
            args.discretize_mu,
            args.discretize_M,
        )
        self.cont_tokenizer = ContinuousScalarTokenizer(*tokenizer_settings)
        
    def _prepare_dataloader(self, 
        datasets_train:List[RLFullDataset], 
        datasets_test:List[RLFullDataset], 
        dataset_weights:List[float],
        envs_problems:Dict,
    ):
        # split training set and evaluation set
        if datasets_train != []:
            train_datasets, val_datasets, test_datasets = [], [], datasets_test
            for dataset in datasets_train:
                dataset_train, dataset_val = dataset.split_dataset(self.args.split)
                val_datasets.append(dataset_val)
                train_datasets.append(dataset_train)
                
            # build BlendableDataset
            dataset_train = BlendableDataset(
                train_datasets, 
                dataset_weights,
                batch_size=self.eval_args.batch_size,
                log_data=self.eval_args.traindata_logger,
                with_dataset_info=True
            )
            dataset_val = BlendableDataset(
                val_datasets, 
                dataset_weights,
                batch_size=self.eval_args.eval_batch_size,
                with_dataset_info=True
            )

            if self.eval_args.batch_num == 0:
                self.eval_args.batch_num = math.ceil(len(dataset_train) / self.eval_args.batch_size)
                sample_num_per_training_epoch = None
            else:
                sample_num_per_training_epoch = self.eval_args.batch_num * self.eval_args.batch_size
            dataloader_train = build_training_data_loader(
                self.eval_args, 
                dataset_train, 
                epoch_total_samples=sample_num_per_training_epoch, 
                is_eval=False
            )

            if self.eval_args.eval_batch_num == 0:
                self.eval_args.eval_batch_num = math.ceil(len(dataset_val) / self.eval_args.eval_batch_size)
                sample_num_per_evaluation_epoch = None
            else:
                sample_num_per_evaluation_epoch = self.eval_args.eval_batch_num * self.eval_args.eval_batch_size
            dataloader_val = build_training_data_loader(
                self.eval_args, 
                dataset_val, 
                epoch_total_samples=sample_num_per_evaluation_epoch,
                is_eval=True
            )

            # The test dataset is the MDP episode dataset associated with dataset_problem
            if len(datasets_test) == len(datasets_train):
                dataset_test = BlendableDataset(
                    test_datasets, 
                    dataset_weights,
                    batch_size=self.eval_args.test_batch_size,
                    with_dataset_info=True
                )

                if self.eval_args.test_batch_num == 0:
                    self.eval_args.test_batch_num = math.ceil(len(dataset_test) / self.eval_args.test_batch_size)
                    sample_num_per_evaluation_epoch = None
                else:
                    sample_num_per_evaluation_epoch = self.eval_args.test_batch_num * self.eval_args.test_batch_size
                dataloader_test = build_training_data_loader(
                    self.eval_args, 
                    dataset_test, 
                    epoch_total_samples=sample_num_per_evaluation_epoch,
                    is_eval=True
                )
            else:
                dataset_test = dataloader_test = None
        else:
            dataloader_train = dataset_train = dataset_val = dataloader_val = dataset_test = dataloader_test = None

        # build problem loader
        problemloader_dict = {}
        for i, (name, dataset) in enumerate(envs_problems.items()):
            env = self.envs[i]
            assert env.env_name == name
            problemloader = DDP_ProblemLoader(dataset, env)
            problemloader_dict[name] = problemloader

            # get random_obj & best_obj
            rnd_obj_value = problemloader.random_obj_array[problemloader.random_obj_array != 0]
            best_obj_value = problemloader.best_obj_array[problemloader.best_obj_array != 0]
            obj_value = np.vstack([rnd_obj_value, best_obj_value])
            obj_value_tensor = torch.tensor(obj_value).to(self.local_rank)
            obj_value_gather_list = [torch.zeros_like(obj_value_tensor) for _ in range(self.world_size)]
            dist.barrier()
            dist.gather(
                obj_value_tensor,
                obj_value_gather_list if self.local_rank == 0 else None, 
                dst = 0
            )

            if self.local_rank == 0:
                obj_info_tensor = torch.cat(obj_value_gather_list, dim=1)
                problem_num = obj_info_tensor.shape[1]
                random_obj_ave = obj_info_tensor[0].mean().item()
                best_obj_ave = obj_info_tensor[1].mean().item()
                random_obj_std = obj_info_tensor[0].std().item()
                best_obj_std = obj_info_tensor[1].std().item()
                print(f'Ave obj value of [{problem_num}] samples in [{name}] is: \n\trandom policy: [{random_obj_ave}], std=[{random_obj_std}]\n\tbest polciy:   [{best_obj_ave}], std=[{best_obj_std}]')
                self.problem_num[name] = problem_num
                self.random_obj_ave[name] = random_obj_ave
                self.best_obj_ave[name] = best_obj_ave
                self.random_obj_std[name] = random_obj_std
                self.best_obj_std[name] = best_obj_std
                
                
        return dataloader_train, dataloader_val, dataloader_test, problemloader_dict, dataset_train, dataset_val, dataset_test

    def _eval_policy(self, sample_action=False, hard_action_constraint=False, regen_times=0):
        if not sample_action:
            assert regen_times == 0

        desc_sample = 'sample' if sample_action else 'greedy'
        desc_constraint = 'constraint' if hard_action_constraint else 'free'
        desc_vote = f'-vote{regen_times}' if regen_times != 0 else ''
        desc = f'{desc_sample}-{desc_constraint}{desc_vote}'
        episode_return = {k: {'AM':0, 'DB1':0} for k in self.args.eval_dataset_names}
        episode_obj = {k: 0 for k in self.args.eval_dataset_names}
        episode_obj_std = {k: 0 for k in self.args.eval_dataset_names}
        episode_safe_ratio = {k: 0 for k in self.args.eval_dataset_names}
        episode_time = {k: 0 for k in self.args.eval_dataset_names}

        for env_name, dataset_name, env in zip(self.args.eval_env_names, self.args.eval_dataset_names, self.envs):
            if self.args.use_ddp_env:                
                ave_return, ave_obj, obj_std, ave_safe_ratio, ave_time_used, episode = evalute_batch_episode(
                    args=self.args, 
                    model=self.gato, 
                    env=env,
                    problemloader=self.problemloader_dict[env_name],
                    cont_tokenizer=self.cont_tokenizer,
                    sample_action=sample_action,
                    hard_action_constraint=hard_action_constraint,
                    desc=f'[GPU{self.global_rank}]: Evaluating on {env_name} ({desc})',
                    device=self.local_rank        
                )
                
                # render greedy policy if necessary
                if self.logger is not None:
                    for i, epi in enumerate(episode):
                        self.logger[env_name].log_episode(
                            desc=desc,
                            is_eval=True,
                            episode=epi, 
                            epoch_num=0, 
                            episode_num=i,
                            time_used=-1,
                            seed=self.eval_args.seed
                        )
            else:
                problemloader = self.problemloader_dict[env_name]
                problemloader.reset()
                iters = self.args.problem_batch_num * self.args.problem_batch_size
                epi_return, epi_obj, epi_safe, epi_time = {'AM':[], 'DB1':[]}, [], [], []
                with tqdm(total=iters, desc=f'[GPU{self.global_rank}]: Evaluating on {env_name} ({desc})', position=self.local_rank) as pbar:
                    # problems 列表中每个元素都是目标环境的一个初始观测（包含随机生成的目标COP问题）
                    # 通过 env._recover 方法还原环境的观测从而设置被评估的COP问题
                    for i in range(iters):
                        problem_info, problem_obj = problemloader.get_problem(1)
                        problem_info = (None if problem_info[0] is None else problem_info[0][0], problem_info[1][0], problem_info[2][0])
                        problem_obj = (problem_obj[0].item(), problem_obj[1].item())
                        ep_ret, ep_obj, ep_safe, ep_len, ep_time, epi = evalute_one_episode(
                            args=self.args, 
                            model=self.gato, 
                            env=env, 
                            cont_tokenizer=self.cont_tokenizer,
                            sample_action=sample_action,
                            hard_action_constraint=hard_action_constraint,
                            regen_times=regen_times,
                            problem_info=problem_info,
                            problem_obj=problem_obj,
                            device=self.local_rank        
                        )
                        epi_safe.append(ep_safe)
                        epi_time.append(ep_time)
                        if ep_safe:
                            epi_return['AM'].append(ep_ret['AM'])
                            epi_return['DB1'].append(ep_ret['DB1'])
                            epi_obj.append(ep_obj)
                            
                        # render greedy policy if necessary
                        if self.logger is not None:    
                            self.logger[env_name].log_episode(
                                desc=desc,
                                is_eval=True,
                                episode=epi, 
                                epoch_num=0, 
                                episode_num=i,
                                time_used=ep_time,
                                seed=self.eval_args.seed
                            )

                        ave_safe_ratio = 0 if epi_safe == [] else np.mean(epi_safe)
                        ave_time_used = 0 if epi_time == [] else np.mean(epi_time)
                        ave_obj = 0 if epi_obj == [] else np.mean(epi_obj)
                        obj_std = 0 if epi_obj == [] else np.std(epi_obj)
                        ave_return_AM = 0 if epi_return['AM'] == [] else np.mean(epi_return['AM'])
                        ave_return_DB1 = 0 if epi_return['DB1'] == [] else np.mean(epi_return['DB1'])
                        ave_return = {'AM':ave_return_AM, 'DB1':ave_return_DB1}

                        info = {
                            'ret_AM': f'{ave_return_AM:.4f}',
                            'ret_DB1': f'{ave_return_DB1:.4f}',
                            'obj' : f'{ave_obj:.4f}',
                            'std' : f'{obj_std:.4f}',
                            'safe': f'{ave_safe_ratio:.2f}',
                            'time': f'{ave_time_used:.2f}',
                        }
                        pbar.set_postfix(info)
                        pbar.update()

            episode_return[dataset_name] = ave_return
            episode_obj[dataset_name] = ave_obj
            episode_obj_std[dataset_name] = obj_std
            episode_safe_ratio[dataset_name] = ave_safe_ratio
            episode_time[dataset_name] = ave_time_used
        
        return episode_return, episode_obj, episode_obj_std, episode_safe_ratio, episode_time

    def _run_batch(self, rl_task_input, batch_dataset_name, batch_raw_obs) -> float:
        with torch.set_grad_enabled(False):
            _, loss, _, _ = self.gato(rl_task_input, batch_dataset_name=batch_dataset_name, batch_raw_obs=batch_raw_obs)
        return loss.item()

    def _run_epoch(self, desc, batch_num, dataloader, is_train=False):
        """Compute the mean loss over every batch yielded by ``dataloader``.

        Args:
            desc: Progress-bar description; prefixed with this process's
                global rank.
            batch_num: Total batch count over all processes; the progress bar
                length is ``ceil(batch_num / world_size)``.
            dataloader: Iterable of ``(rl_task_input, batch_dataset_name,
                batch_raw_obs)`` triples.
            is_train: Accepted for interface compatibility; not read here.

        Returns:
            Mean of the per-batch losses (NaN if the loader yields nothing).
        """
        losses = []
        bar_desc = f'[GPU{self.global_rank}]: ' + desc
        bar_total = math.ceil(batch_num / self.world_size)
        with tqdm(total=bar_total, desc=bar_desc, position=self.local_rank) as pbar:
            for rl_task_input, dataset_names, batch_raw_obs in dataloader:
                # NOTE(review): the return value is discarded -- this assumes
                # ``rl_task_input.to`` moves its contents in place; confirm.
                rl_task_input.to(device=self.local_rank)
                # Keep only the first element of every dataset-name entry.
                dataset_names = [entry[0] for entry in dataset_names]
                batch_loss = self._run_batch(rl_task_input, dataset_names, batch_raw_obs)
                losses.append(batch_loss)

                recent_mean = np.mean(losses[-20:])
                pbar.set_postfix({
                    'loss': f'{batch_loss:.2f}',
                    'ave loss (latest 20)': f'{recent_mean:.2f}',
                })
                pbar.update()

                # log train data if necessary (checked after every batch)
                wants_data_log = (
                    self.logger is not None
                    and self.dataset_train is not None
                    and self.eval_args.traindata_logger
                )
                if wants_data_log:
                    logged_data = self.dataset_train.get_logged_data()
                    for env_name in self.eval_args.eval_envs:
                        self.logger[env_name].log_data(logged_data, seed=self.eval_args.seed, is_train=False)

        return np.mean(losses)

    def evaluate(self):
        """Run the full evaluation: optional loss check, policy rollouts and reporting.

        Steps:
            1. Put the model in eval mode and, if ``eval_args.check_loss`` is
               set, compute the train/val(/test) losses.
            2. For every configured evaluation setting, roll out the policy
               with ``_eval_policy`` and gather the per-dataset metrics from
               all processes onto global rank 0 (metrics are averaged over
               processes).
            3. On global rank 0, print the summary, append it to
               ``{eval_args.ckpt_performance_path}/result.txt``, dump the
               evaluation arguments to
               ``{base_path}/visualize/eval/result/config.json`` and the
               training hyper-parameters to
               ``{eval_args.ckpt_performance_path}/info.json``.

        Returns:
            None.
        """
        self.gato.eval()
        # Original author's note: a fixed context length is used when rolling
        # out with memory (TransformerXL backbone); it is switched off for the
        # loss check.
        self.raw_gato.transformer.same_length = False

        # check train & eval loss
        if self.eval_args.check_loss:
            self._check_losses()

        # Policy evaluation settings. Only 'greedy_cst' is enabled; the other
        # combinations are kept here as a reference:
        #   'greedy_free': sample_action=False, hard_action_constraint=False
        #   'sample_free': sample_action=True,  hard_action_constraint=False
        #   f'vote{regen_times}_free': sample_action=True, hard_action_constraint=False, regen_times=eval_args.regen_times
        #   'sample_cst':  sample_action=True,  hard_action_constraint=True
        #   f'vote{regen_times}_cst': sample_action=True, hard_action_constraint=True, regen_times=eval_args.regen_times
        eval_setting = {
            'greedy_cst': lambda: self._eval_policy(
                    sample_action=False, hard_action_constraint=True),
        }

        # Use this structure to store eval results
        result = {
            setting: {
                env_name: {'return_AM':0, 'return_DB1':0, 'problem_num':0, 'obj_best':0, 'obj_model':0, 'obj_random':0, 'safe':0, 'time':0}
                for env_name in self.args.eval_env_names
            } for setting in eval_setting
        }

        # ``dst`` of dist.gather is a global rank, so the collecting/reporting
        # process is global rank 0 (local rank 0 exists once per node).
        is_main = self.global_rank == 0
        # Column order of the gathered metric tensor.
        dataset_order = sorted(self.args.eval_dataset_names)

        # evaluate policy
        results = {setting: [] for setting in eval_setting}
        self.raw_gato.transformer.same_length = self.eval_args.use_mem
        with torch.inference_mode():
            set_seed(self.eval_args.seed)

            for setting, eval_func in eval_setting.items():
                if is_main:
                    print(f'\n\n' + '-'*20 + f' Eval policy under setting [{setting}]; seed [{self.eval_args.seed}] ' + '-'*20)

                # eval on all dataset
                epi_return_dict, epi_obj_dict, episode_obj_std_dict, epi_safe_dict, epi_time_dict = eval_func()

                # Rows: return_AM, return_DB1, obj, obj_std, safe, time.
                metric_rows = [
                    [epi_return_dict[k]['AM'] for k in dataset_order],
                    [epi_return_dict[k]['DB1'] for k in dataset_order],
                    [epi_obj_dict[k] for k in dataset_order],
                    [episode_obj_std_dict[k] for k in dataset_order],
                    [epi_safe_dict[k] for k in dataset_order],
                    [epi_time_dict[k] for k in dataset_order],
                ]
                result_tensor = torch.tensor(metric_rows).to(self.local_rank)

                # gather result from all GPU
                gather_list = (
                    [torch.zeros_like(result_tensor) for _ in range(self.world_size)]
                    if is_main else None
                )
                dist.gather(result_tensor, gather_list, dst=0)
                if is_main:
                    results[setting].append(torch.stack(gather_list).mean(axis=0))

        if is_main:
            self._write_summary(result, results)

    def _check_losses(self):
        """Compute and print the train/val(/test) losses on the loaders that were built."""
        if self.dataloader_train is None or self.dataloader_val is None:
            # No training data was supplied, so there is nothing to compute.
            print(f'[GPU{self.global_rank}]: skip loss check, train/val dataloaders were not built')
            return

        losses = {}
        with torch.inference_mode():
            # NOTE(review): the second set_epoch argument belongs to the
            # project's sampler; its meaning is not visible here.
            self.dataloader_val.sampler.set_epoch(0, False)
            self.dataloader_train.sampler.set_epoch(0, True)
            losses['train'] = self._run_epoch('Calculating train loss', self.eval_args.batch_num, self.dataloader_train, is_train=True)
            losses['eval'] = self._run_epoch('Calculating eval loss', self.eval_args.eval_batch_num, self.dataloader_val, is_train=False)
            if self.dataloader_test is not None:
                self.dataloader_test.sampler.set_epoch(0, False)
                # ``test_batch_num`` is the attribute maintained for the test split.
                losses['test'] = self._run_epoch('Calculating test loss', self.eval_args.test_batch_num, self.dataloader_test, is_train=False)

        for split_name, value in losses.items():
            print(f'[GPU{self.global_rank}]: {split_name} loss = {value:.4f}')

    def _write_summary(self, result, results):
        """Print the gathered metrics and persist them together with the run configuration (reporting rank only)."""
        # result summary
        result_file_path = f'{self.eval_args.ckpt_performance_path}/result.txt'
        create_file_if_not_exist(result_file_path)
        config_dir = f'{base_path}/visualize/eval/result'
        os.makedirs(config_dir, exist_ok=True)
        with open(f'{config_dir}/config.json', 'w') as f:
            f.write(json.dumps(vars(self.eval_args), indent=4))

        # Metric columns follow sorted(eval_dataset_names); pair every column
        # with the environment evaluated on that dataset so rows are not
        # misattributed when dataset and environment names sort differently.
        report_env_names = [
            env_name
            for _, env_name in sorted(zip(self.args.eval_dataset_names, self.args.eval_env_names))
        ]

        time.sleep(1)
        print('\n')
        console_pp = pprint.PrettyPrinter(indent=4)
        with open(result_file_path, 'a') as result_file:
            # Write to the file through explicit streams instead of swapping
            # sys.stdout, which is not exception-safe.
            file_pp = pprint.PrettyPrinter(indent=4, stream=result_file)
            for setting, res in results.items():
                # average over all random seeds
                res_summary = torch.stack(res).mean(axis=0)
                return_AM_summary = res_summary[0]
                return_DB1_summary = res_summary[1]
                obj_summary = res_summary[2]
                obj_std_summary = res_summary[3]
                safe_summary = res_summary[4]
                time_summary = res_summary[5]

                header = '='*20+f' {setting} '+'='*20
                print(header)
                result_file.write(header + '\n')
                for i, env_name in enumerate(report_env_names):
                    entry = result[setting][env_name]
                    entry.update({
                        'return_AM': return_AM_summary[i].item(),
                        'return_DB1': return_DB1_summary[i].item(),
                        'obj_model': obj_summary[i].item(),
                        'obj_model_std': obj_std_summary[i].item(),
                        'obj_best': self.best_obj_ave[env_name],
                        'obj_best_std': self.best_obj_std[env_name],
                        'obj_random': self.random_obj_ave[env_name],
                        'obj_random_std': self.random_obj_std[env_name],
                        'safe': safe_summary[i].item(),
                        'time': time_summary[i].item(),
                        'problem_num': self.problem_num[env_name],
                    })
                    rounded = {k: round(v,5) for k,v in entry.items()}

                    # print the result
                    print(env_name)
                    console_pp.pprint(rounded)

                    # save the result to the txt file
                    print(env_name, file=result_file)
                    file_pp.pprint(rounded)
                print()
                result_file.write('\n')

        # Training hyper-parameters of the evaluated checkpoint; identical for
        # every setting, so they are written once.
        info = {
            'train_info':{
                'train_iters': self.args.train_iters,
                'batch_num': self.args.batch_num,
                'batch_size': self.args.batch_size_vaild,
                'start_grad_accum': self.args.start_grad_accum,
                'end_grad_accum': self.args.end_grad_accum,
                'grad_accum_step_incr_style': self.args.grad_accum_step_incr_style
            },
            'early_stopping': {
                'use_early_stop': self.args.use_early_stopping,
                'early_stopping_patience': self.args.early_stopping_patience,
                'early_stopping_delta': self.args.early_stopping_delta,
                'eval_interval': self.args.eval_interval
            },
            'learning_rate': {
                'warmup_ratio': self.args.lr_warmup_ratio,
                'decay_ratio': self.args.lr_decay_ratio,
                'decay_style': self.args.lr_decay_style,
                'decay_factor': self.args.lr_decay_factor,
                'lr_begin': self.args.lr_begin,
                'lr_max': self.args.lr_max,
            }
        }
        with open(f'{self.eval_args.ckpt_performance_path}/info.json', 'w') as f:
            json.dump(info, f, indent=4)
        # TODO: the bar-chart rendering of the summary was disabled dead code
        # (it referenced undefined names) and has been dropped; re-implement
        # it on top of ``result`` if plots are needed again.