import json
import os
import random
import shutil
import time

import numpy as np
import torch
import torch.utils.data as Data
from torch import nn, optim

from model import *
from utils import *
from data import *

# NOTE: these stay after the wildcard imports above so that, as before, they take
# precedence over any same-named objects those modules export.
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# NOTE(review): import-time side effects - prints the literal text 'None' and then blocks
# for 5 seconds before the rest of this module is executed. This looks like leftover
# debugging; confirm nothing depends on the pause before removing it.
print('None')
time.sleep(5)


# Reproducibility: seed every RNG and force deterministic cuDNN / cuBLAS behaviour.
def setup_deterministic_mode(seed=42):
    """Seed every RNG used by training and switch cuDNN/cuBLAS to deterministic mode.

    Args:
        seed: value given to the torch, numpy, ``random`` and (when present) CUDA
            seeders. Defaults to 42, the value this script has always used.
    """
    # cuBLAS only produces reproducible results with a fixed workspace configuration.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Seed the generators of every visible GPU as well.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed all RNGs as soon as this module is imported, before train() can build any model or data.
setup_deterministic_mode()


# Checkpointing: save / restore model, optimizer, scheduler and all RNG states.
def save_checkpoint(model, optimizer, epoch, iteration, dataloader_rng_state, filename, scheduler=None):
    """Serialize training state plus every RNG state to ``filename`` with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: epoch number recorded in the checkpoint.
        iteration: iteration number recorded in the checkpoint.
        dataloader_rng_state: despite its name, an object exposing ``get_state()`` -
            in this file the ``torch.Generator`` that shuffles the training DataLoader -
            or ``None`` to store no loader state.
        filename: destination path.
        scheduler: optional LR scheduler; its ``state_dict()`` is stored when given.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'iteration': iteration,
        'rng_state': torch.get_rng_state(),
        # torch.cuda.get_rng_state() raises on CPU-only machines; store None there.
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        'numpy_rng_state': np.random.get_state(),
        'random_rng_state': random.getstate(),
        'dataloader_rng_state': dataloader_rng_state.get_state() if dataloader_rng_state is not None else None,
    }

    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

    torch.save(checkpoint, filename)

def load_checkpoint(model, optimizer, filename, scheduler=None):
    """Restore a checkpoint written by ``save_checkpoint`` into ``model`` / ``optimizer``.

    The global torch / CUDA / numpy / ``random`` RNG states are restored as a side effect.
    The CUDA state is skipped when it was not saved or no GPU is available.

    Args:
        model: module to load ``model_state_dict`` into.
        optimizer: optimizer to load ``optimizer_state_dict`` into.
        filename: checkpoint path.
        scheduler: optional LR scheduler. Its state is loaded when the checkpoint has one.

    Returns:
        When ``scheduler`` is given: ``(model, optimizer, epoch, iteration, generator, scheduler)``.
        This holds even if the checkpoint carries no scheduler state; in that case the
        scheduler is returned unchanged.
        Otherwise: ``(model, optimizer, epoch, iteration, generator)``.
        ``generator`` is a new ``torch.Generator`` holding the saved DataLoader shuffle
        state, if one was saved.
    """
    # weights_only=False unpickles arbitrary objects (needed for the numpy / random RNG
    # states): only load checkpoints from trusted sources.
    # map_location='cpu' lets GPU-written checkpoints load on CPU-only machines;
    # load_state_dict copies values onto the model's / optimizer's own devices.
    checkpoint = torch.load(filename, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    iteration = checkpoint['iteration']

    # Restore all random states.
    torch.set_rng_state(checkpoint['rng_state'])
    cuda_rng_state = checkpoint.get('cuda_rng_state')
    if cuda_rng_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
    np.random.set_state(checkpoint['numpy_rng_state'])
    random.setstate(checkpoint['random_rng_state'])

    # Fresh generator carrying the saved DataLoader shuffle state.
    dataloader_rng_state = torch.Generator()
    saved_loader_state = checkpoint.get('dataloader_rng_state')
    if saved_loader_state is not None:
        dataloader_rng_state.set_state(saved_loader_state)

    if scheduler is None:
        print('no scheduler_state')
        return model, optimizer, epoch, iteration, dataloader_rng_state

    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        print('no scheduler_state')
    # Keep the 6-tuple shape whenever a scheduler was passed so callers can always unpack it.
    return model, optimizer, epoch, iteration, dataloader_rng_state, scheduler



def _new_norm_record(model):
    """Return an empty per-parameter norm history: name -> {'param_norm', 'grad_norm', 'vt_norm'}."""
    return {
        name: {'param_norm': [], 'grad_norm': [], 'vt_norm': []}
        for name, _ in model.named_parameters()
    }


def _train_loss(args, criterion, outputs, dec_outputs, batch_size):
    """Compute one batch's loss for the already-validated ``args.train_method``."""
    if args.train_method == 'LTP':
        # Last-token prediction: only the final position of each sequence is scored.
        last_logits = outputs.view(batch_size, args.seq_len, args.vocab_size)[:, -1, :]
        return criterion(last_logits, dec_outputs[:, -1].view(-1))
    # 'NTP' (next-token prediction): every position is scored.
    return criterion(outputs.view(batch_size * args.seq_len, args.vocab_size), dec_outputs.view(-1))


def _record_param_stats(model, optimizer, record_dict):
    """Append the current param / grad / optimizer ``exp_avg_sq`` norms of each trainable parameter."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        record = record_dict[name]
        record['param_norm'].append(param.detach().norm(2).item())

        grad = param.grad
        record['grad_norm'].append(grad.detach().norm(2).item() if grad is not None else 0.0)
        if grad is None:
            continue

        # 'exp_avg_sq' (Adam's second-moment estimate) appears in the optimizer state
        # only once step() has run, so the first iteration records no vt_norm entry.
        state = optimizer.state[param]
        if state and 'exp_avg_sq' in state:
            record['vt_norm'].append(state['exp_avg_sq'].norm(2).item())


def train_step(args, model, train_data_loader, optimizer, criterion, device, logger, clip=1, scheduler=None):
    """Train ``model`` for one pass over ``train_data_loader``.

    Args:
        args: needs ``train_method`` ('LTP' = loss on the last position only,
            'NTP' = loss on every position), ``seq_len``, ``vocab_size`` and ``epoch``
            (used only in the per-iteration log line).
        model: called as ``model(dec_inputs)``. It must return a 2-tuple whose first
            element is what the loss is computed on.
        train_data_loader: iterable of ``(dec_inputs, dec_outputs)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss function.
        device: device the batches are moved to.
        logger: receives one info line per iteration.
        clip: max total gradient norm passed to ``clip_grad_norm_``.
        scheduler: optional LR scheduler, stepped once per iteration.

    Returns:
        ``(mean_loss, max_eigenvalues, iter_losses, record_dict, lrs)``:
        - ``mean_loss``: the sample-weighted mean of the batch losses.
        - ``max_eigenvalues``: always empty; kept for the caller's tuple shape.
        - ``iter_losses``: each batch loss.
        - ``record_dict``: maps parameter name to per-iteration ``param_norm`` /
          ``grad_norm`` / ``vt_norm`` lists.
        - ``lrs``: the learning rate of param group 0 at the start of each iteration.

    Raises:
        ValueError: unknown ``args.train_method`` or a loader that yields no samples.
    """
    if args.train_method not in ('LTP', 'NTP'):
        raise ValueError(f"unsupported args.train_method {args.train_method!r}; expected 'LTP' or 'NTP'")

    model.train()
    epoch_loss = 0
    total_samples = 0
    iter_loss_list = []
    max_eigenvalues = []
    lr = []
    record_dict = _new_norm_record(model)

    for i, (dec_inputs, dec_outputs) in enumerate(train_data_loader):
        lr.append(optimizer.param_groups[0]['lr'])
        logger.info(f'epoch: {args.epoch} iter:{i} data: {dec_inputs[0]}')

        optimizer.zero_grad()
        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
        outputs, _ = model(dec_inputs)

        batch_size = dec_inputs.size(0)
        total_samples += batch_size

        loss = _train_loss(args, criterion, outputs, dec_outputs, batch_size)
        loss_value = loss.item()
        epoch_loss += loss_value * batch_size  # weight by batch size for a per-sample mean
        iter_loss_list.append(loss_value)
        loss.backward()

        # Norms are recorded on the raw gradients: after backward(), before clipping / step().
        _record_param_stats(model, optimizer, record_dict)

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

    if total_samples == 0:
        raise ValueError('train_data_loader yielded no samples; cannot compute the epoch loss')
    return epoch_loss / total_samples, max_eigenvalues, iter_loss_list, record_dict, lr


def test_step(args, model, test_data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    total_samples = 0
    
    for i, (dec_inputs, dec_outputs) in enumerate(test_data_loader):
        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
        outputs, _ = model(dec_inputs)
        
        batch_size = dec_inputs.size(0)  # 
        total_samples += batch_size
        
        if args.model == 'DNN' or args.model == 'DNN_averaged':
            loss = criterion(outputs.view(batch_size, args.vocab_size), dec_outputs[:,-1].view(-1))
        else:
            loss = criterion(outputs.view(batch_size, args.seq_len, args.vocab_size)[:,-1,:], dec_outputs[:,-1].view(-1))
        
        epoch_loss += loss.item() * batch_size  # 
    
    return epoch_loss / total_samples  # 



# Last-token evaluation helpers: accuracy and prediction-deviation distribution.
def last_word_acc(args, model, data_loader):
    """Return the fraction of samples whose last-position prediction equals the target.

    Args:
        args: needs ``model`` and ``seq_len``.
            - 'DNN' / 'DNN_averaged' models emit one prediction per sample.
            - Any other model emits ``seq_len`` predictions per sample, and the last one is used.
        model: called as ``model(dec_inputs)``; argmax of the first element of its result
            is the prediction.
        data_loader: iterable of ``(dec_inputs, dec_outputs)`` batches; the target is
            ``dec_outputs[:, -1]``.

    Returns:
        ``correct / total`` as a float, or ``nan`` when the loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total_samples = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for dec_inputs, dec_outputs in data_loader:
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)
            total_samples += dec_inputs.size(0)

            if args.model == 'DNN' or args.model == 'DNN_averaged':
                predictions = outputs.argmax(dim=-1).view(-1)
            else:
                predictions = outputs.argmax(dim=-1).view(-1, args.seq_len)[:, -1]
            correct += (predictions == dec_outputs[:, -1]).sum().item()

    if total_samples == 0:
        return float('nan')
    return correct / total_samples

def last_word_devi(args, model, data_loader):
    """Return the empirical distribution of (last-position prediction - target).

    Args:
        args: needs ``seq_len``; the model output is read as ``seq_len`` predictions
            per sample, and only the last one is used.
        model: called as ``model(dec_inputs)``; argmax of the first element of its result
            is the prediction.
        data_loader: iterable of ``(dec_inputs, dec_outputs)`` batches; the target is
            ``dec_outputs[:, -1]``.

    Returns:
        A dict mapping each observed deviation (numpy scalar) to the fraction of samples
        with that deviation. The dict is empty when the loader yields nothing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    total_samples = 0
    batch_deviations = []

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for dec_inputs, dec_outputs in data_loader:
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)
            total_samples += dec_inputs.size(0)

            predictions = outputs.argmax(dim=-1).view(-1, args.seq_len)
            batch_deviations.append(predictions[:, -1] - dec_outputs[:, -1])

    if not batch_deviations:
        return {}

    # Concatenate once instead of growing a tensor every batch (quadratic copying).
    deviations = torch.cat(batch_deviations, dim=0)
    unique_deviations, deviation_counts = torch.unique(deviations, return_counts=True)
    deviation_probs = deviation_counts.float() / total_samples

    return dict(zip(unique_deviations.cpu().numpy(), deviation_probs.cpu().numpy()))


def get_accuracy(args, model, data_loader_group, train_percent, test_percent, my_logger):
    '''
    Evaluate last-token accuracy and return ``(train_acc, test_acc, acc_list)``.

    For the 'composition' / 'composition_more_anchor' targets, only the '43_xel' dataset
    is evaluated. Its accuracy becomes the single entry of ``acc_list``, the distribution
    of last-token prediction deviations is logged, and ``train_acc`` / ``test_acc`` stay 0.

    For every other target, each dataset in ``args.data_name`` is evaluated and its
    accuracy is appended to ``acc_list``. It is also folded into ``train_acc``
    (``args.data_train[i] == 1``) or ``test_acc`` (otherwise), weighted by
    ``args.data_percent[i] / train_percent`` or ``args.data_percent[i] / test_percent``.
    '''
    if args.target in ['composition_more_anchor', 'composition']:
        probe_name = '43_xel'
        probe_acc = last_word_acc(args, model, data_loader_group[probe_name])
        my_logger.info(f'data type: {probe_name} \t Acc: {probe_acc}')

        deviation_dict = last_word_devi(args, model, data_loader_group['43_xel'])
        my_logger.info("Deviation Distribution:")
        for deviation, prob in deviation_dict.items():
            my_logger.info(f"  deviation: {deviation} \t Acc: {prob:.4f}")
        return 0, 0, [probe_acc]

    train_acc = test_acc = 0
    acc_list = []
    for idx, name in enumerate(args.data_name):
        acc = last_word_acc(args, model, data_loader_group[name])
        acc_list.append(acc)

        share = args.data_percent[idx]
        if args.data_train[idx] == 1:
            train_acc += acc * share / train_percent
        else:
            test_acc += acc * share / test_percent

        my_logger.info(f'data type: {name} \t Acc: {acc}')

    return train_acc, test_acc, acc_list



def _get_loss_of_each_data(args, model, data_loader_group, criterion, device):
    '''
        data_train=0loss，lossloss
        ，，0
    '''
    test_loss = 0
    total_samples = 0
    loss_list = []
    for i, data_name in enumerate(args.data_name):
        if args.data_train[i] == 0:
            data_loader = data_loader_group[data_name]
            tmp_loss = test_step(args, model, data_loader, criterion, device)
            loss_list.append(tmp_loss)

            total_samples += len(data_loader.dataset)
            test_loss += tmp_loss * len(data_loader.dataset)
        else:
            loss_list.append(0)
        
    test_loss = test_loss / total_samples

    return loss_list, test_loss






def train(args, datas, **kwargs):
    '''
    Train a model on ``datas`` according to ``args`` and persist everything under ``args.working_dir``.

    Required:
        args: run configuration. It is also mutated here: ``num_batches`` is set,
            ``data_percent`` is normalised to sum to 1, and ``epoch`` holds the
            absolute index of the current epoch.
        datas: mapping from each name in ``args.data_name`` to its data. It is passed to
            ``get_train_data`` / ``get_data_loader_group`` and saved to ``data/datas.npz``.
        **kwargs: forwarded to ``get_model``, ``get_optimizer`` and (online learning)
            ``get_data``, and recorded in ``config.json``.

    Writes under ``args.working_dir``:
        - ``train_log.log``
        - ``config.json``
        - ``data/`` (the train dataset and the raw datas)
        - ``src/`` (a copy of the source tree)
        - ``model/`` (initial weights and checkpoints)
        - ``loss/`` (metric histories), plus whatever the plot helpers produce
    '''
    # ---- data ----
    train_data_loader, generator, train_dataset = get_train_data(args, datas)
    torch.save(train_dataset, f'{args.working_dir}/data/train_dataset.pt')
    args.num_batches = len(train_data_loader)

    # Evaluation loaders, looked up by dataset name.
    data_loader_group = get_data_loader_group(args, datas)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    my_logger = Log(f'{args.working_dir}/train_log.log')

    for data_name in args.data_name:
        my_logger.info(f'data type: {data_name:<20} ex: {datas[data_name][0]}')

    # ---- model / optimisation ----
    model = get_model(args, device, **kwargs)

    # Whole-run history of the per-iteration norms collected by train_step.
    record_dict_his = {
        name: {'param_norm': [], 'grad_norm': [], 'vt_norm': []}
        for name, _ in model.named_parameters()
    }

    my_logger.info(f'Total parameters: {sum(p.numel() for p in model.parameters())}')

    # Targets equal to token id 0 are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)

    optimizer, scheduler = get_optimizer(model, args, **kwargs)

    # Normalise the dataset mixing proportions so they sum to 1.
    percent_list = np.array(args.data_percent)
    percent_list = percent_list / np.sum(percent_list)
    args.data_percent = percent_list.tolist()

    # ---- persist the run configuration, the data and the source ----
    save_args = dict(vars(args))
    save_args.update(kwargs)
    for data_name in args.data_name:
        save_args[f'data_size_{data_name}'] = len(datas[data_name])
    save_to_json_noindent(save_args, f'{args.working_dir}/config.json')

    np.savez(f'{args.working_dir}/data/datas.npz', **datas)

    # Source paths are relative to the current working directory.
    for src_file in ['main.py', 'data.py', 'train.py', 'test.py', 'script.py']:
        shutil.copy(src_file, f'{args.working_dir}/src/{src_file}')
    for src_dir in ['utils', 'model', 'data_generator']:
        shutil.copytree(src_dir, f'{args.working_dir}/src/{src_dir}', dirs_exist_ok=True)

    # ---- metric histories ----
    train_loss_his = []        # mean training loss, one entry per epoch
    max_eigenvalue_his = []    # second value returned by train_step, one entry per epoch
    train_loss_iter_his = []   # training loss of every iteration
    test_loss_his = []         # size-weighted loss over the data_train == 0 datasets, per epoch
    group_loss_his = []        # per-dataset loss per epoch (0 for training datasets)

    acc_epoch_his = []         # loop epochs at which accuracy was evaluated
    train_acc_his = []         # weighted accuracy over data_train == 1 datasets
    test_acc_his = []          # weighted accuracy over data_train == 0 datasets
    group_acc_his = []         # per-dataset accuracy
    lr_recoder = []            # learning rate at every iteration

    # Total (normalised) share of the training datasets and of the held-out datasets.
    train_percent, test_percent = 0, 0
    for i, _ in enumerate(args.data_name):
        if args.data_train[i] == 1:
            train_percent += args.data_percent[i]
        else:
            test_percent += args.data_percent[i]

    # ---- optional resume ----
    load_epoch = 0
    start_epoch = 0
    if args.load_checkpoint is not None:
        # Depending on the checkpoint contents, load_checkpoint may or may not append the
        # scheduler as a 6th element, so accept both shapes.
        restored = load_checkpoint(model, optimizer, args.load_checkpoint, scheduler=scheduler)
        model, optimizer, load_epoch, _, generator = restored[:5]
        if len(restored) == 6:
            scheduler = restored[5]

        # Checkpoints are written after their epoch finished training: continue with the next one.
        start_epoch = load_epoch + 1

        # Rebuild the training DataLoader around the restored shuffle generator.
        train_data_loader = DataLoader(
            train_data_loader.dataset,
            batch_size=args.batch_size,
            shuffle=True,
            generator=generator,
            drop_last=True,
            num_workers=0,
            pin_memory=False,
        )

        my_logger.info(f'Loaded checkpoint from epoch {load_epoch}')

    print('args.epoch', start_epoch)
    print('training...')
    # NOTE(review): after a resume this stores the restored weights, not a fresh initialisation.
    torch.save(model.state_dict(), f'{args.working_dir}/model/model_ini.pt')
    for epoch in tqdm(range(args.n_epoch)):
        # `epoch` counts epochs of this invocation; args.epoch is the absolute epoch index.
        args.epoch = epoch + start_epoch
        is_last_epoch = epoch == args.n_epoch - 1

        # Accuracy is measured before this epoch's training.
        if epoch % args.print_acc_epoch == 0 or is_last_epoch:
            train_acc, test_acc, acc_list = get_accuracy(args, model, data_loader_group, train_percent, test_percent, my_logger)

            acc_epoch_his.append(epoch)
            train_acc_his.append(train_acc)
            test_acc_his.append(test_acc)
            group_acc_his.append(acc_list)

        train_loss, max_eigenvalue, train_loss_iter, record_dict, lr = train_step(
            args, model, train_data_loader, optimizer, criterion, device, my_logger, args.clip, scheduler=scheduler
        )

        for name, history in record_dict_his.items():
            for key in ('param_norm', 'grad_norm', 'vt_norm'):
                history[key] += record_dict[name][key]
        tmp_loss_list, test_loss = _get_loss_of_each_data(args, model, data_loader_group, criterion, device)

        train_loss_his.append(train_loss)
        train_loss_iter_his += train_loss_iter
        max_eigenvalue_his.append(max_eigenvalue)
        group_loss_his.append(tmp_loss_list)
        test_loss_his.append(test_loss)
        lr_recoder += lr

        if epoch % args.print_loss_epoch == 0:
            my_logger.info(f'Epoch: {epoch:<5}  Train Loss: {train_loss:.4e}  Test Loss: {test_loss:.4e}')

        # Periodic checkpoint. The absolute epoch is stored so that a checkpoint written by a
        # resumed run resumes at the right place; the file name keeps the loop epoch.
        if (epoch % args.save_model_epoch == 0) or is_last_epoch:
            save_checkpoint(model, optimizer, args.epoch, 0, generator, f"{args.working_dir}/model/model_{epoch}.pt", scheduler)

        # Loss spike: the train loss rose by more than two orders of magnitude over the previous
        # epoch while still above 1e-4 (only checked once epoch > save_model_epoch + 1).
        if (
            epoch > args.save_model_epoch + 1
            and np.log10(train_loss_his[-1]) - np.log10(train_loss_his[-2]) > 2
            and np.log10(train_loss) > -4
        ):
            # Report the value the spike test compared against ([-2]); the earlier [-10] index
            # could raise IndexError in early epochs. The file name matches the one saved below.
            my_logger.info(f'Epoch: {epoch:<5}  Train Loss Spike Detected! Previous {train_loss_his[-2]:2e} And Current {train_loss:2e}; Saved Model at model_{epoch}_spike.pt')
            my_logger.info(f'The gradient norms of all the parameters')
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    my_logger.info(f'Layer: {name}  Gradient Norm: {param.grad.data.norm(2).item()}')

            save_checkpoint(model, optimizer, args.epoch, 0, generator, f"{args.working_dir}/model/model_{epoch}_spike.pt", scheduler)

        # Persist the histories and refresh the plots.
        if ((epoch % args.plot_loss_acc_epoch == 0) and (epoch != 0)) or is_last_epoch:
            loss_dir = f'{args.working_dir}/loss'
            np.save(f'{loss_dir}/train_loss_his.npy', np.array(train_loss_his))
            torch.save(max_eigenvalue_his, f'{loss_dir}/max_eigenvalue_his.pt')
            np.save(f'{loss_dir}/train_loss_iter_his.npy', np.array(train_loss_iter_his))
            np.save(f'{loss_dir}/test_loss_his.npy', np.array(test_loss_his))
            np.save(f'{loss_dir}/group_loss_his.npy', np.array(group_loss_his))
            np.save(f'{loss_dir}/acc_epoch_his.npy', np.array(acc_epoch_his))
            np.save(f'{loss_dir}/train_acc_his.npy', np.array(train_acc_his))
            np.save(f'{loss_dir}/test_acc_his.npy', np.array(test_acc_his))
            np.save(f'{loss_dir}/group_acc_his.npy', np.array(group_acc_his))
            np.save(f'{loss_dir}/lr_recoder.npy', np.array(lr_recoder))
            torch.save(record_dict_his, f'{loss_dir}/record_dict_his.pt')

            plot_loss(args.working_dir)
            plot_acc(args.working_dir)

            # Per-dataset plots only when some dataset is marked for display.
            if np.sum(args.data_show) != 0:
                plot_loss_of_each_data(args.working_dir)
                plot_acc_of_each_data(args.working_dir)

        if args.online_learning:
            print('online_learning')
            # NOTE(review): the regenerated `datas` is only logged here. train_data_loader and
            # data_loader_group are not rebuilt from it, so later epochs still use the original
            # data. Confirm whether that is intended.
            datas = get_data(args, **kwargs)
            my_logger.info(f'data type: {args.data_name[0]:<20} ex: {datas[args.data_name[0]][0]}')

    print('training finished!')



