import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from collections import  OrderedDict
import os
import tqdm
import copy
import torch.distributed as dist
import sys
sys.path.append('../')

from model import GPTLM, GPT
from .optim_schedule import ScheduledOptim
from utils._record import _PerformanceMeter
from utils.tool import *


class GPTTrainer:
    """
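    GPTTrainer trains a GPT model on the tasks listed in ``task_dict`` and provides the
    matching evaluation and prediction loops.

    The class was originally described as using two LM training methods:

        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction

    NOTE(review): the section numbers look inherited from a BERT trainer, and no
    next-sentence objective is visible in this class - confirm and update this list.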

    """

    def __init__(self, model, args: dict,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None, val_dataloader: DataLoader = None,
                 lr: float = 1e-4, weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, optimizer=None, scheduler=None,
                 task_dict=None):
        """
        Store the training components and, if requested, restore a checkpoint.

        :param model: model to train; when ``args['if_load_model'] == 1`` its state dict is
            restored from the checkpoint named by ``args['load_model']``
        :param args: configuration dict. Keys read here: 'Local', 'save_dir', 'max_len',
            'predict_max_len', 'info_element_cnt', 'voiceTimes_cnt', 'action_element_cnt',
            'if_load_model', and (only when loading) 'load_model' and 'global_step'
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param val_dataloader: validation dataset data loader [can be None]
        :param lr: accepted for interface compatibility; not used in this method
        :param weight_decay: accepted for interface compatibility; not used in this method
        :param warmup_steps: accepted for interface compatibility; not used in this method
        :param with_cuda: use "cuda:0" when CUDA is available, otherwise the CPU
        :param cuda_devices: accepted for interface compatibility; not used in this method
        :param log_freq: logging frequency of the batch iteration
        :param optimizer: optimizer stepped by ``train``
        :param scheduler: learning-rate scheduler stepped by ``train``; when a checkpoint is
            loaded it is advanced ``args['global_step']`` times
        :param task_dict: mapping whose keys are the task names; also passed to the
            performance meters
        :raises ValueError: if ``task_dict`` is None
        """
        if task_dict is None:
            # Fail early with a clear message instead of an AttributeError on ``.keys()``.
            raise ValueError("task_dict is required: it defines the tasks to train and evaluate")

        np.set_printoptions(suppress=True)
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.local = args['Local']
        self.save_dir = args['save_dir']
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.task_dict = task_dict
        self.task_name = list(task_dict.keys())
        self.task_num = len(task_dict)
        # Both counters stay at 0 unless a checkpoint is restored below.
        self.global_step_init = 0
        self.global_step = 0
        self.max_len = args["max_len"]
        self.predict_max_len = args["predict_max_len"]
        self.info_onehot_cnt = args['info_element_cnt'] + 1
        self.voiceTimes_onehot_cnt = args['voiceTimes_cnt']
        self.action_onehot_cnt = args["action_element_cnt"]

        self.model = model

        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.val_dataloader = val_dataloader

        self.meter = _PerformanceMeter(self.task_dict, self.device)
        # Five evaluation meters; ``test`` uses indices 0-3 per play mode and index 4 for all samples.
        self.test_meter = [_PerformanceMeter(self.task_dict, self.device) for _ in range(5)]

        self.log_freq = log_freq

        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

        if args['if_load_model'] == 1:
            load_path = "TBT/" + args['load_model']
            print('load path:', load_path)

            # Up to three download attempts; the last failure is re-raised unchanged.
            max_retries = 3
            for attempt in range(1, max_retries + 1):
                try:
                    buffer = BytesIO(bucket.get_object(load_path).read())
                    break
                except Exception:
                    if attempt == max_retries:
                        raise

            # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
            checkpoint = torch.load(buffer, map_location=self.device)
            self.model.load_state_dict(checkpoint)

            print('load model finished')

            self.global_step_init = args['global_step']
            self.global_step = args['global_step']

            print('init global step: ', self.global_step_init)

            # A scheduler is optional (e.g. prediction-only use); only fast-forward one that exists.
            if self.scheduler is not None:
                for _ in range(self.global_step_init):
                    self.scheduler.step()
                # NOTE(review): torch.optim schedulers expose get_last_lr() for reading the
                # current rate; get_lr() is kept because the scheduler type is not visible here.
                print('schedular after init: ', self.scheduler.get_lr())

    def train(self, args):
        """
        Train the model for ``args['epochs']`` epochs.

        One evaluation on ``self.test_dataloader`` runs before the first update. Each
        optimisation step is: forward pass, loss, ``zero_grad``, backward, gradient-norm
        clipping to 1, optimizer step, scheduler step. On the local run or on distributed
        rank 0, losses and parameter/gradient sizes are printed every 100 global steps;
        every 10000 global steps the train meter is displayed, the model is evaluated and
        a checkpoint is saved under ``self.save_dir``.

        :param args: configuration dict; only ``args['epochs']`` is read here
        """
        num_epochs = args['epochs']
        train_loader, train_batch = self._prepare_dataloaders(self.train_dataloader)

        # Baseline evaluation; the total epoch count is what gets passed as the epoch label.
        self.test(self.test_dataloader, num_epochs, mode='test', writer=None, global_step=self.global_step)

        for epoch in range(num_epochs):
            self.model.train()
            self.meter.record_time('begin')
            for _ in range(train_batch):
                # test() leaves the model in eval mode, so switch back before every step.
                self.model.train()
                data = self._process_data(train_loader)
                data = {key: value.to(self.device) for key, value in data.items()}
                train_labels = self._get_labels(data)

                # Call the module itself (not .forward) so that registered hooks run.
                mask_lm_output = self.model(data, True)
                train_preds = self.process_preds(mask_lm_output)

                train_losses = self._compute_loss(train_preds, train_labels, self.global_step, data)

                self.optimizer.zero_grad()
                self.backward(train_losses, self.global_step)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                self.optimizer.step()
                self.scheduler.step()

                self.global_step += 1

                if self.global_step % 100 != 0:
                    continue
                # Only the local run or distributed rank 0 logs, evaluates and saves.
                if not (self.local or int(os.environ.get("RANK")) == 0):
                    continue

                dt = get_time()
                print('train_losses  ', self.global_step, dt, torch.clone(train_losses).detach().cpu().numpy().round(5))
                print('model para size', parameters_size(self.model).detach().cpu().numpy().round(5))
                print('model grad size', grad_size(self.model).detach().cpu().numpy().round(5))

                # Every 10000 steps (always also a multiple of 100): full report, eval, checkpoint.
                if self.global_step % 10000 != 0:
                    continue

                print('global_step', self.global_step)
                for i, group in enumerate(self.optimizer.state_dict()['param_groups']):
                    print('lr: group ', i, group['lr'])
                print('train_losses  ', torch.clone(train_losses).detach().cpu().numpy().round(5))
                print('model para size', parameters_size(self.model).detach().cpu().numpy().round(5))
                print('model grad size', grad_size(self.model).detach().cpu().numpy().round(5))
                self.meter.update(train_preds, train_labels, data, None, data['ds_to_sub_end'])
                self.meter.record_time('end')
                self.meter.get_score()
                self.meter.display(epoch=epoch, mode='train', writer=None, global_step=self.global_step)
                self.meter.reinit()

                self.test(self.test_dataloader, epoch, mode='test', writer=None, global_step=self.global_step)

                save_path = os.path.join(self.save_dir, "model_ckpt_{}.pt".format(self.global_step))
                save_to_oss(save_path, self.model)

    def test(self, test_dataloader, epoch=None, mode='test', return_improvement=False, global_step=None, batch_data=None, writer=None):
        self.model.eval()
        for i in range(5):
            self.test_meter[i].record_time('begin')
        test_dataloader, test_batch = self._prepare_dataloaders(test_dataloader)
        with torch.no_grad():
            for batch_index in range(test_batch):
                data = self._process_data(test_dataloader)
                data = {key: value.to(self.device) for key, value in data.items()}
                test_labels = self._get_labels(data)
                play_mode_flag = data['ori_play_mode']
                mask_lm_output = self.model.forward(data, False)
                test_preds = self.process_preds(mask_lm_output)

                self._compute_trigger_metrics(test_preds, test_labels, data)
                for i in [0,1,2,3,4]:
                    if i == 4:
                        new_test_data,new_test_preds, new_test_labels = self._reduce_test_data(data, 1, test_preds, test_labels)
                        test_losses = self._compute_eval_loss(new_test_preds, new_test_labels,i, self.global_step, new_test_data)
                        self.test_meter[i].update(new_test_preds, new_test_labels, new_test_data, None, new_test_data['ds_to_sub_end'])
                    else:
                        reduce_flag = torch.where(play_mode_flag==i, 1, 0)
                        new_test_data,new_test_preds, new_test_labels = self._reduce_test_data(data, reduce_flag, test_preds, test_labels)
                        test_losses = self._compute_eval_loss(new_test_preds, new_test_labels, i, self.global_step, new_test_data)
                        self.test_meter[i].update(new_test_preds, new_test_labels, new_test_data, None, new_test_data['ds_to_sub_end'])
            for i in [0,1,2,3,4]:
                self.test_meter[i].record_time('end')
                self.test_meter[i].get_score()
                if (self.local or (not self.local and int(os.environ.get("RANK")) == 0)):
                    self.test_meter[i].display(epoch=epoch, mode=mode, flag=i, writer = writer, global_step=global_step)
                self.test_meter[i].reinit()
            self._reset_trigger_metrics()
    
    def _compute_trigger_metrics(self, preds, labels, data):
        '''
        Hook for trigger metrics based on domain prior knowledge.

        ``test`` calls it once per evaluation batch with the processed predictions, the
        labels and the batch. No computation is implemented here; it returns None.
        '''

    def _reset_trigger_metrics(self):
        '''
        Counterpart hook that ``test`` calls once after the last evaluation batch.

        It is defined as a no-op so that ``test`` does not end in an AttributeError;
        there is no trigger-metric state in this class to clear.
        '''


    def _reduce_test_data(self, data, reduce_flag, pred, gt):
        new_pred = copy.deepcopy(pred)
        new_gt = copy.deepcopy(gt)
        new_data = copy.deepcopy(data)
        new_data['random_mask'] = new_data['random_mask']*reduce_flag
        new_data['padding_mask'] = new_data['padding_mask']*reduce_flag
        return new_data,new_pred,new_gt       

    def _compute_eval_loss(self, preds, gts, i, global_step=None, batch_data=None):
        train_losses = torch.zeros(self.task_num, self.max_len).to(self.device)
        for tn, task in enumerate(self.task_name):
            if task in ('action','trigger'):
                train_losses[tn] = self.test_meter[i].losses[task]._update_loss(preds[task], gts[task], task, global_step, batch_data)
            else:
                train_losses[tn] = self.test_meter[i].losses[task]._update_loss(preds[task], gts[task], batch_data)
                
        return train_losses

    
    def predict(self,task_list, test_dataloader):
        r'''The test process of multi-task learning.
        '''
        self.model.eval()
        task_result = {}
        task_target = {}
        for task in task_list:
            task_result[task] = []
            task_target[task] = []
        for i in ['caseid', 'path_id', 'seg_id', 'time', 'taskid_ori', 'nearest_label_gpst', 'scene', 'mask_index', 'ds_to_sub_end']:
            task_result[i] =[]
            task_target[i] = []

        test_dataloader, test_batch = self._prepare_dataloaders(test_dataloader)
        print('test_batch', test_batch)

        with torch.no_grad():
            for batch_index in range(test_batch):
                test_inputs = self._process_data(test_dataloader)
                test_inputs = {key: value.to(self.device) for key, value in test_inputs.items()}
                test_labels = self._get_labels(test_inputs)
                mask_lm_output = self.model.forward(test_inputs, False)
                test_preds = self.process_preds(mask_lm_output)
                batch_size = test_labels['trigger'].size(0)
                seq_len = test_labels['trigger'].size(1)
                
                token_mask = test_inputs['random_mask'].reshape(batch_size, seq_len)
                padding_mask = test_inputs['padding_mask'].reshape(batch_size, seq_len)
                all_mask = token_mask * padding_mask

                print(test_batch, batch_index)

                ds_to_sub_end = test_inputs['ds_to_sub_end'][torch.where(all_mask>0)]

                pred_trigger = list(np.array((test_preds['trigger'][torch.where(all_mask>0)].view(-1)*(ds_to_sub_end.view(-1))).to('cpu')))
                real_trigger = list(np.array((test_labels['trigger'][torch.where(all_mask>0)]*(ds_to_sub_end.view(-1))).to('cpu')))
                
                pred_action = [str(x)[1:-1] for x in np.array((test_preds['action'][torch.where(all_mask>0)]>0.5).long().to('cpu'))]
                real_action = [str(x)[1:-1] for x in np.array(test_labels['action'][torch.where(all_mask>0)].long().to('cpu'))]

                pred_info = list(np.array(test_preds['info'][torch.where(all_mask>0)].max(-1)[1].long().to('cpu')))
                real_info = list(np.array(test_labels['info'][torch.where(all_mask>0)].view(-1).long().to('cpu')))

                pred_voiceTimes = list(np.array(test_preds['voiceTimes'][torch.where(all_mask>0)].max(-1)[1].long().to('cpu')))
                real_voiceTimes = list(np.array(test_labels['voiceTimes'][torch.where(all_mask>0)].long().to('cpu')))


                task_result['trigger'].extend(pred_trigger)
                task_result['action'].extend(pred_action)
                task_result['info'].extend(pred_info)
                task_result['voiceTimes'].extend(pred_voiceTimes)

                task_target['trigger'].extend(real_trigger)
                task_target['action'].extend(real_action)
                task_target['info'].extend(real_info)
                task_target['voiceTimes'].extend(real_voiceTimes)

        return task_result, task_target


    def seq_loss_recursive_predict(self,task_list, test_dataloader):
        r'''Predict step by step, feeding each step's predictions back into the inputs.

        For every batch the method walks positions ``0 .. predict_max_len - 1``. At each
        step it slices a window out of a copy of the batch, marks one position of that
        window as the one to predict, runs the model, writes the predicted ``play_*``
        values back into the batch at the current position, and records predictions,
        labels and identifying fields for the positions selected by
        ``random_mask * padding_mask``.

        :param task_list: task names for which result/target lists are created
        :param test_dataloader: iterable of batches (dicts of tensors) that supports ``len``
        :return: ``(task_result, task_target)`` dicts of flat Python lists; the metadata
            fields ('caseid', 'path_id', ...) are only appended to ``task_result``
        '''
        self.model.eval()
        task_result = {}
        task_target = {}
        for task in task_list:
            task_result[task] = []
            task_target[task] = []
        # Metadata columns recorded alongside the predictions.
        for i in ['caseid', 'path_id', 'seg_id', 'time', 'taskid_ori', 'nearest_label_gpst', 'scene', 'mask_index', 'ds_to_sub_end']:
            task_result[i] =[]
            task_target[i] = []

        test_dataloader, test_batch = self._prepare_dataloaders(test_dataloader)
        print('test_batch', test_batch)

        with torch.no_grad():
            for batch_index in range(test_batch):

                test_inputs = self._process_data(test_dataloader)
                test_inputs = {key: value.to(self.device) for key, value in test_inputs.items()}
                test_labels = self._get_labels(test_inputs)

                print('trigger', test_labels['trigger'].shape)

                batch_size = test_labels['trigger'].size(0)
                # Every window fed to the model is reshaped to (batch_size, max_len) below.
                seq_len = self.max_len
                

                # predict step by step
                for i in range(self.predict_max_len):
                    # Work on copies so slicing and masking do not touch the full batch;
                    # test_inputs itself is only modified by the write-back further down.
                    current_inputs = copy.deepcopy(test_inputs)
                    current_labels = copy.deepcopy(test_labels)
                    # NOTE(review): the literals 3 (here) and 2 (current_index below) look tied
                    # to max_len == 3 - TODO confirm whether they should follow self.max_len.
                    if i<3:
                        # Early steps: use the first max_len positions and predict position i.
                        for k, v in current_inputs.items():
                            if len(v.shape) > 1:
                                current_inputs[k] = v[:,:self.max_len]

                        for k, v in current_labels.items():
                            if len(v.shape) > 1:
                                current_labels[k] = v[:,:self.max_len]
                        
                        current_index = i
                    else:
                        # Later steps: use the max_len positions ending at i and predict index 2.
                        for k, v in current_inputs.items():
                            if len(v.shape) > 1:
                                current_inputs[k] = v[:,i-self.max_len+1:i+1]

                        for k, v in current_labels.items():
                            if len(v.shape) > 1:
                                current_labels[k] = v[:,i-self.max_len+1:i+1]
                        current_index = 2

                    # Select only the current position, and zero the padding mask after it.
                    current_inputs["random_mask"] = torch.zeros_like(current_inputs['random_mask'])
                    current_inputs["random_mask"][:,current_index] = 1
                    current_inputs["padding_mask"][:,current_index+1:] = 0
                    # valid predict position
                    token_mask = current_inputs["random_mask"].reshape(batch_size, seq_len)
                    padding_mask =  current_inputs["padding_mask"].reshape(batch_size, seq_len)
                    all_mask = token_mask * padding_mask

                    # Overwrite the play_* inputs at the predicted position with -1 so the
                    # model does not receive the true values there. The fill tensors are
                    # created without an explicit device.
                    current_inputs["play_info"][:,current_index] = torch.ones(self.info_onehot_cnt) * -1
                    current_inputs["play_voiceTimes"][:,current_index] = torch.ones(self.voiceTimes_onehot_cnt) * -1
                    current_inputs["play_action"][:,current_index] = torch.ones(self.action_onehot_cnt) * -1
                    current_inputs["play_trigger"][:,current_index] = torch.FloatTensor([-1])

                    # current prediction position
                    mask_lm_output = self.model.forward(current_inputs, False)
                    test_preds = self.process_preds(mask_lm_output)

                    # Write the predictions for the current position back into the full
                    # batch at position i, so later steps are built from predicted values.
                    test_inputs["play_info"][:,i] = self.one_hot(test_preds['info'].argmax(-1)[:,current_index] , self.info_onehot_cnt)
                    test_inputs["play_voiceTimes"][:,i] = self.one_hot(test_preds['voiceTimes'].argmax(-1)[:,current_index] , self.voiceTimes_onehot_cnt)
                    test_inputs["play_action"][:,i] = (test_preds['action'] > 0.5).long()[:,current_index]
                    test_inputs["play_trigger"][:,i] = test_preds['trigger'][:,current_index,0]

                    # record current prediction position valid predict value
                    ds_to_sub_end_record = list(np.array((current_inputs['ds_to_sub_end'][torch.where(all_mask>0)].long().view(-1)).to('cpu')))
                    
                    # Trigger predictions and labels are both multiplied by ds_to_sub_end.
                    ds_to_sub_end = current_inputs['ds_to_sub_end'][torch.where(all_mask>0)]
                    pred_trigger = list(np.array((test_preds['trigger'][torch.where(all_mask>0)].view(-1)*(ds_to_sub_end.view(-1))).to('cpu')))
                    real_trigger = list(np.array((current_labels['trigger'][torch.where(all_mask>0)]*(ds_to_sub_end.view(-1))).to('cpu')))
                    
                    # Actions are stored as the printed numpy row without its brackets.
                    pred_action = [str(x)[1:-1] for x in np.array((test_preds['action'][torch.where(all_mask>0)]>0.5).long().to('cpu'))]
                    real_action = [str(x)[1:-1] for x in np.array(current_labels['action'][torch.where(all_mask>0)].long().to('cpu'))]

                    pred_info = list(np.array(test_preds['info'].max(-1)[1].long()[torch.where(all_mask>0)].to('cpu')))
                    real_info = list(np.array(current_labels['info'][torch.where(all_mask>0)].view(-1).long().to('cpu')))

                    pred_voiceTimes = list(np.array(test_preds['voiceTimes'].max(-1)[1].long()[torch.where(all_mask>0)].to('cpu')))
                    real_voiceTimes = list(np.array(current_labels['voiceTimes'][torch.where(all_mask>0)].long().to('cpu')))

                    task_result['trigger'].extend(pred_trigger)
                    task_result['action'].extend(pred_action)
                    task_result['info'].extend(pred_info)
                    task_result['voiceTimes'].extend(pred_voiceTimes)
                    
                    task_target['trigger'].extend(real_trigger)
                    task_target['action'].extend(real_action)
                    task_target['info'].extend(real_info)
                    task_target['voiceTimes'].extend(real_voiceTimes)
                    task_result['ds_to_sub_end'].extend(ds_to_sub_end_record)

                    # Per-sample ids are repeated along the window so they can be indexed
                    # with the same position mask as the predictions.
                    caseid = list(np.array(current_inputs['caseid'].unsqueeze(-1).repeat(1,self.max_len)[torch.where(all_mask>0)].view(-1).to('cpu')))
                    path_id = list(np.array(current_inputs['path_id'].unsqueeze(-1).repeat(1,self.max_len)[torch.where(all_mask>0)].view(-1).to('cpu')))
                    seg_id = list(np.array(current_inputs['seg_id'].unsqueeze(-1).repeat(1,self.max_len)[torch.where(all_mask>0)].view(-1).to('cpu')))
                    # NOTE(review): 'nearest_label_gpst', 'time' and 'mask_index' are filled
                    # with the caseid values - presumably placeholders; confirm.
                    nearest_label_gpst = caseid
                    taskid_ori = list(np.array(current_inputs['ori_play_mode'][torch.where(all_mask>0)].view(-1).to('cpu')))
                    taskid_ori = list(map(int, taskid_ori))
                    time = caseid
                    scene = list(np.array(current_inputs['scene'].unsqueeze(-1).repeat(1,self.max_len)[torch.where(all_mask>0)].long().view(-1).to('cpu')))
                    
                    task_result['caseid'].extend(caseid)
                    task_result['path_id'].extend(path_id)
                    task_result['seg_id'].extend(seg_id)
                    task_result['time'].extend(time)
                    task_result['taskid_ori'].extend(taskid_ori)
                    task_result['nearest_label_gpst'].extend(nearest_label_gpst)
                    task_result['scene'].extend(scene)
                    task_result['mask_index'].extend(caseid)

                print(test_batch, batch_index)

        return task_result, task_target

    
    def _process_data(self, loader):
        """
        Return the next batch, restarting the iterator once it is exhausted.

        :param loader: ``[dataloader, iterator]`` pair as built by ``_prepare_dataloaders``;
            the iterator slot is replaced in place when a new pass starts
        :return: the next batch
        :raises StopIteration: if the dataloader yields nothing even after the restart
        """
        try:
            return next(loader[1])
        except StopIteration:
            # Only exhaustion triggers a restart; genuine data-loading errors propagate
            # instead of being hidden behind a silent re-iteration.
            loader[1] = iter(loader[0])
            return next(loader[1])
    
    def _prepare_dataloaders(self, dataloaders):
        """
        Pair a dataloader with a fresh iterator over it.

        :param dataloaders: iterable of batches that supports ``len``
        :return: ``([dataloader, iterator], number_of_batches)``
        """
        batch_iter = iter(dataloaders)
        pair = [dataloaders, batch_iter]
        return pair, len(dataloaders)

    def _get_labels(self, data):
        """
        Pick the label entry of every task out of a batch.

        :param data: batch dict containing one entry per name in ``self.task_name``
        :return: new dict mapping each task name to ``data[task]``
        """
        return {task: data[task] for task in self.task_name}
    
    def process_preds(self, preds, task_name=None):
        for task in self.task_name:
            if task in ('trigger','action'):
                preds[task] = torch.sigmoid(preds[task])
            else:
                preds[task] = torch.softmax(preds[task],dim=-1)
        return preds
    
    def backward(self, losses, global_step):
        cnt = losses.shape[0]
        loss = torch.pow(losses.prod(0) + 1e-7, 1./cnt)
        loss = loss.prod()

        loss.backward()

    
    def _compute_loss(self, preds, gts, global_step=None, batch_data=None):
        train_losses = torch.zeros(self.task_num, self.max_len).to(self.device)
        for tn, task in enumerate(self.task_name):
            if task in ('action','trigger'):
                train_losses[tn] = self.meter.losses[task]._update_loss(preds[task], gts[task], task, global_step, batch_data)
            else:
                train_losses[tn] = self.meter.losses[task]._update_loss(preds[task], gts[task], batch_data)
      
        return train_losses

    def one_hot(self, labels, num):
        seq_len = labels.size(0)
        onehot = torch.LongTensor(np.eye(num)[labels.cpu().numpy().reshape(-1)]).reshape(seq_len, -1)
        return onehot
