import json
import torch
import torch.utils.data as Data
from torch import nn, optim
import numpy as np
import shutil
from model import *
from utils import *
from data import MyDataSet
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler 



def train_step(args, model, train_data_loader, optimizer, criterion, device, clip=1, scheduler=None):
    """Run one training epoch and return the mean per-batch loss.

    Only the prediction for the last position of every sequence is scored
    against the last target token.

    Args:
        args: Namespace providing ``model``, ``vocab_size`` and ``seq_len``.
        model: Callable returning ``(outputs, extra)``; ``extra`` is ignored.
        train_data_loader: Sized iterable of ``(dec_inputs, dec_outputs)``
            batches, each documented in the original as ``[batch_size, tgt_len]``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to the last-position logits and targets.
        device: Device the batches are moved to.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        scheduler: Optional LR scheduler, stepped once after the whole epoch.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    is_dnn = args.model in ('DNN', 'DNN_averaged')
    epoch_loss = 0
    for dec_inputs, dec_outputs in train_data_loader:
        optimizer.zero_grad()
        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
        # For the non-DNN models the original notes outputs as
        # [batch_size * tgt_len, tgt_vocab_size].
        outputs, _ = model(dec_inputs)

        # The batch dimension is inferred (-1) instead of being forced to
        # args.batch_size, so a smaller final batch no longer breaks view().
        if is_dnn:
            last_logits = outputs.view(-1, args.vocab_size)
        else:
            last_logits = outputs.view(-1, args.seq_len, args.vocab_size)[:, -1, :]
        loss = criterion(last_logits, dec_outputs[:, -1].view(-1))

        epoch_loss += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

    # The learning-rate schedule advances per epoch, not per batch.
    if scheduler is not None:
        scheduler.step()

    return epoch_loss / len(train_data_loader)


def test_step(args, model, test_data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    for i, (dec_inputs, dec_outputs) in enumerate(test_data_loader):  
        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
        outputs, _ = model(dec_inputs) 
        if args.model == 'DNN' or args.model == 'DNN_averaged':
            loss = criterion(outputs.view(args.batch_size, args.vocab_size), dec_outputs[:,-1].view(-1))
        else:
            loss = criterion(outputs.view(args.batch_size, args.seq_len, args.vocab_size)[:,-1,:], dec_outputs[:,-1].view(-1))
        epoch_loss += loss.item()

    return epoch_loss / len(test_data_loader)


def last_word_acc(args, model, data, seq_len, batch_size):
    """Return the fraction of sequences whose last token is predicted correctly.

    Args:
        args: Namespace providing ``model`` (selects how outputs are reshaped).
        model: Callable returning ``(outputs, extra)``; ``extra`` is ignored.
        data: Sequences wrapped in ``MyDataSet`` for evaluation.
        seq_len: Sequence length used to reshape the non-DNN predictions.
        batch_size: Batch size of the evaluation loader.

    Returns:
        Number of correct last-token predictions divided by the dataset size.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    train_dataset = MyDataSet(data)
    # shuffle=True is kept from the original: it does not affect the accuracy
    # but removing it would change how the global RNG is consumed.
    data_loader = Data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size,
                                  drop_last=False, collate_fn=train_dataset.padding_batch)
    is_dnn = args.model in ('DNN', 'DNN_averaged')

    # Pure inference: disable autograd so no graph is kept per batch.
    with torch.no_grad():
        for dec_inputs, dec_outputs in data_loader:
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)
            predictions = outputs.argmax(dim=-1)
            if is_dnn:
                # One prediction per sequence.
                last_pred = predictions.view(-1)
            else:
                # One prediction per position; keep the last position only.
                last_pred = predictions.view(-1, seq_len)[:, -1]
            correct += (last_pred == dec_outputs[:, -1]).sum().item()

    return correct / len(data_loader.dataset)


def get_accuracy(args, model, datas, mask_percent, unmask_percent, my_logger):
    """Compute last-token accuracy per data type and the masked/unmasked averages.

    Each data type in ``args.data_name`` is scored on its train and test
    groups; an empty group scores 0 without calling the model. Per-type
    results are logged, then folded into two weighted averages: types with
    ``args.data_mask[i] == 0`` are weighted by
    ``args.data_percent[i] / unmask_percent``, all others by
    ``args.data_percent[i] / mask_percent``.

    Returns:
        ``(acc_train, acc_test, acc_train_unmask, acc_test_unmask,
        acc_train_mask, acc_test_mask)`` where the first two are per-type lists.
    """
    def _score(seq_group):
        # Empty groups are reported as 0 rather than evaluated.
        if seq_group == []:
            return 0
        return last_word_acc(args, model, seq_group, args.seq_len, args.batch_size)

    per_type_train = []
    per_type_test = []
    unmasked_train, unmasked_test = 0, 0
    masked_train, masked_test = 0, 0

    for idx, name in enumerate(args.data_name):
        train_group = datas['train_seq_group'][name].tolist()
        test_group = datas['test_seq_group'][name].tolist()

        train_acc = _score(train_group)
        test_acc = _score(test_group)

        per_type_train.append(train_acc)
        per_type_test.append(test_acc)

        my_logger.info(f'data type: {name} \tTrain Acc: {train_acc} \tTest Acc: {test_acc}')

        share = args.data_percent[idx]
        if args.data_mask[idx] == 0:
            unmasked_train += train_acc * share / unmask_percent
            unmasked_test += test_acc * share / unmask_percent
        else:
            masked_train += train_acc * share / mask_percent
            masked_test += test_acc * share / mask_percent

    return (per_type_train, per_type_test,
            unmasked_train, unmasked_test,
            masked_train, masked_test)



def _get_loss_mask(args, model, datas, criterion, device, mode='train'):
    """Return the sample-weighted mean loss over the masked data types.

    For every data type with ``args.data_mask[i] == 1`` the loss of its
    ``{mode}_seq_group`` is computed with ``test_step`` and weighted by the
    size of that group's dataset.

    Groups that are empty, or too small to yield a single full batch
    (the loader uses ``drop_last=True``), contribute nothing. Previously an
    empty group either raised because ``data_loader`` was unbound or silently
    reused the loader of the preceding group.

    Args:
        args: Namespace providing ``data_name``, ``data_mask`` and ``batch_size``
            (plus whatever ``test_step`` reads).
        model: Model forwarded to ``test_step``.
        datas: Mapping holding ``'{mode}_seq_group'`` -> per-type arrays.
        criterion: Loss forwarded to ``test_step``.
        device: Device forwarded to ``test_step``.
        mode: ``'train'`` or ``'test'``; selects which group is evaluated.

    Returns:
        The weighted mean loss, or ``1e-9`` when no masked sample was evaluated.
    """
    weighted_loss = 0
    data_num_mask = 0
    for i, data_name in enumerate(args.data_name):
        if args.data_mask[i] != 1:
            continue

        seq_group = datas[f'{mode}_seq_group'][data_name].tolist()
        if not seq_group:
            # Nothing to evaluate for this type.
            continue

        tmp_dataset = MyDataSet(seq_group)
        data_loader = Data.DataLoader(tmp_dataset, shuffle=True, batch_size=args.batch_size,
                                      drop_last=True, collate_fn=tmp_dataset.padding_batch)
        if len(data_loader) == 0:
            # Fewer samples than one batch: drop_last leaves no batch, and
            # test_step would divide by zero.
            continue

        tmp_loss = test_step(args, model, data_loader, criterion, device)
        group_size = len(tmp_dataset)
        data_num_mask += group_size
        weighted_loss += tmp_loss * group_size

    if data_num_mask == 0:
        # Small positive placeholder kept from the original.
        return 1e-9

    return weighted_loss / data_num_mask






def train(args, datas, **kwargs):
    """Train a model from a saved checkpoint, logging, saving and plotting as it goes.

    The model is built with ``get_model``, initialised from a checkpoint on
    disk, and trained for ``args.n_epoch`` epochs (numbered from 1). Accuracy
    is measured before the epoch's training step on every
    ``args.print_acc_epoch``-th epoch; model weights are saved every
    ``args.save_model_epoch`` epochs; histories are written to
    ``{args.working_dir}/loss`` and plotted every
    ``args.plot_loss_acc_epoch`` epochs. All three also happen on the last
    two epochs so the end of training is always captured.

    Side effects: ``args.data_percent`` is replaced by its normalised version
    (summing to 1), and the effective configuration is written to
    ``{args.working_dir}/config.json``.

    Args:
        args: Run configuration namespace.
        datas: Mapping with ``'train_seq_group'``, ``'test_seq_group'``,
            ``'train_data_loader'`` and ``'test_data_loader'``.
        **kwargs: Forwarded to ``get_model`` and ``get_optimizer`` and stored
            in the saved configuration.

    Returns:
        None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = get_model(args, device, **kwargs)

    # NOTE(review): the warm-start checkpoint location is hard-coded; consider
    # making it configurable.
    tmp_dir = 'result/GPT_2_step_reasoning/single_chain_search-seed_1-N_200000-3L1H_shown_in_paper'
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    state_dict = torch.load(f'{tmp_dir}/model/model_0.pt', map_location=device)
    model.load_state_dict(state_dict)

    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params}')

    criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)

    optimizer, scheduler = get_optimizer(model, args, **kwargs)

    # The scheduler is advanced once before the first epoch (kept from the
    # original). Guarded because train_step accepts a missing scheduler.
    if scheduler is not None:
        scheduler.step()

    my_logger = Log(f'{args.working_dir}/train_log.log')

    # Normalise the data proportions so they sum to 1.
    percent_list = np.array(args.data_percent)
    percent_list = percent_list / np.sum(percent_list)
    args.data_percent = percent_list.tolist()

    # Persist the effective configuration together with the dataset sizes.
    save_args = dict(vars(args))
    save_args.update(kwargs)
    for data_name in args.data_name:
        save_args[f'train_datasize_{data_name}'] = len(datas['train_seq_group'][data_name])
        save_args[f'test_datasize_{data_name}'] = len(datas['test_seq_group'][data_name])
    save_to_json_noindent(save_args, f'{args.working_dir}/config.json')

    # for file in ['main.py', 'data.py', 'train.py', 'test.py', 'script.py']:
    #     shutil.copy(file, f'{args.working_dir}/src/{file}')
    # for dir in ['utils', 'model', 'data_generator']:
    #     shutil.copytree(dir, f'{args.working_dir}/src/{dir}', dirs_exist_ok=True)

    train_loss_his = []
    test_loss_his = []
    train_loss_mask_his = []
    test_loss_mask_his = []
    acc_epoch_his = []
    train_acc_his = []
    test_acc_his = []

    # Total share of masked (data_mask == 1) and unmasked data types.
    mask_percent, unmask_percent = 0, 0
    for i in range(len(args.data_name)):
        if args.data_mask[i] == 1:
            mask_percent += args.data_percent[i]
        else:
            unmask_percent += args.data_percent[i]
    has_mask = mask_percent != 0

    # The masked histories are only appended to and saved when has_mask is true.
    acc_train_mask_his, acc_test_mask_his = [], []
    acc_train_unmask_his, acc_test_unmask_his = [], []

    loss_dir = f'{args.working_dir}/loss'

    print('training...')
    for epoch in range(1, args.n_epoch + 1):
        # Epochs are 1-based, so the final one is n_epoch. The original only
        # tested n_epoch - 1 and therefore skipped the true final epoch; that
        # epoch stays included so previously produced artefacts still appear.
        is_final = epoch >= args.n_epoch - 1

        if epoch % args.print_acc_epoch == 0 or is_final:
            acc_train, acc_test, acc_train_unmask, acc_test_unmask, acc_train_mask, acc_test_mask \
                = get_accuracy(args, model, datas, mask_percent, unmask_percent, my_logger)

            acc_epoch_his.append(epoch)
            train_acc_his.append(acc_train)
            test_acc_his.append(acc_test)
            if has_mask:
                acc_train_mask_his.append(acc_train_mask)
                acc_test_mask_his.append(acc_test_mask)
            acc_train_unmask_his.append(acc_train_unmask)
            acc_test_unmask_his.append(acc_test_unmask)

            my_logger.info(f'Train Acc Unmask: {acc_train_unmask} \tTest Acc Unmask: {acc_test_unmask}')
            if has_mask:
                my_logger.info(f'Train Acc Mask: {acc_train_mask} \tTest Acc Mask: {acc_test_mask}')

        train_loss = train_step(args, model, datas['train_data_loader'], optimizer, criterion, device, args.clip, scheduler=scheduler)
        test_loss = test_step(args, model, datas['test_data_loader'], criterion, device)

        train_loss_his.append(train_loss)
        test_loss_his.append(test_loss)

        train_loss_mask = _get_loss_mask(args, model, datas, criterion, device, mode='train')
        test_loss_mask = _get_loss_mask(args, model, datas, criterion, device, mode='test')

        train_loss_mask_his.append(train_loss_mask)
        test_loss_mask_his.append(test_loss_mask)

        if epoch % args.print_loss_epoch == 0:
            my_logger.info(f'Epoch: {epoch:<5}  Train Loss: {train_loss:.4e}  Test Loss: {test_loss:.4e}')

        if epoch % args.save_model_epoch == 0 or is_final:
            torch.save(model.state_dict(), f'{args.working_dir}/model/model_{epoch}.pt')

        if epoch % args.plot_loss_acc_epoch == 0 or is_final:
            # Save the loss and accuracy histories; the plotting helpers are
            # only given the working directory.
            np.save(f'{loss_dir}/train_loss_his.npy', np.array(train_loss_his))
            np.save(f'{loss_dir}/test_loss_his.npy', np.array(test_loss_his))
            np.save(f'{loss_dir}/train_loss_mask_his.npy', np.array(train_loss_mask_his))
            np.save(f'{loss_dir}/test_loss_mask_his.npy', np.array(test_loss_mask_his))
            np.save(f'{loss_dir}/acc_epoch_his.npy', np.array(acc_epoch_his))
            np.save(f'{loss_dir}/train_acc_his.npy', np.array(train_acc_his))
            np.save(f'{loss_dir}/test_acc_his.npy', np.array(test_acc_his))
            if has_mask:
                np.save(f'{loss_dir}/acc_train_mask_his.npy', np.array(acc_train_mask_his))
                np.save(f'{loss_dir}/acc_test_mask_his.npy', np.array(acc_test_mask_his))
            np.save(f'{loss_dir}/acc_train_unmask_his.npy', np.array(acc_train_unmask_his))
            np.save(f'{loss_dir}/acc_test_unmask_his.npy', np.array(acc_test_unmask_his))

            plot_loss_of_mask_unmask_data(args.working_dir, x_axis='epoch')

            plot_acc_of_mask_unmask_data(args.working_dir)

            if np.sum(args.data_show) != 0:
                plot_acc_of_each_data(args.working_dir)

    print('training finished!')



