import json
import torch
import torch.utils.data as Data
from torch import nn, optim
import numpy as np
import shutil
from model import *
from utils import *
from data import *
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler 
from collections.abc import Iterable

class ExponentialLoss(nn.Module):
    """Exponential loss ``mean(exp(-y * f(x)))`` for binary classification.

    Labels are expected to be 0/1 and are mapped to -1/+1 before the loss is
    evaluated.

    Args:
        ignore_index: Label value whose samples are excluded from the loss.
            ``None`` (the default) keeps every sample.
    """

    def __init__(self, ignore_index=None):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, outputs, labels):
        """Return the mean exponential loss as a 0-dim tensor.

        Args:
            outputs: Raw model scores, one score per sample
                (e.g. shape ``[batch_size, 1]`` or ``[batch_size]``).
            labels: 0/1 labels, one per sample.

        Returns:
            A scalar tensor. When every sample is ignored the result is zero
            but still attached to the autograd graph of ``outputs``.
        """
        # Drop the samples whose label equals ``ignore_index``.
        if self.ignore_index is not None:
            mask = labels != self.ignore_index
            outputs = outputs[mask]
            labels = labels[mask]
            if outputs.numel() == 0:
                # Every sample was ignored. Derive the zero from ``outputs``
                # so it keeps the dtype/device and stays connected to the
                # graph; a fresh ``torch.tensor(0.0)`` would make
                # ``loss.backward()`` fail.
                return outputs.sum() * 0.0

        # Map labels 0 -> -1 and 1 -> +1. Both operands are flattened so that
        # a ``[batch, 1]`` label tensor cannot broadcast into ``[batch, batch]``.
        y = (2 * labels - 1).reshape(-1)
        f_x = outputs.reshape(-1)

        margin = y * f_x
        return torch.exp(-margin).mean()

def calculate_F_norm(model: nn.Module, model0_params: dict) -> dict:
    """Measure how far each parameter has moved from a reference snapshot.

    Args:
        model: Module whose current parameters are inspected.
        model0_params: Mapping from parameter name to the reference tensor.
            Parameters of ``model`` whose name is not a key are skipped.

    Returns:
        Dict mapping parameter name to the Frobenius norm (a Python float) of
        ``current - reference``.
    """
    diff = {}
    with torch.no_grad():
        for name, params in model.named_parameters():
            if name not in model0_params:
                continue
            # With no ``ord``/``dim`` the tensor is flattened and the 2-norm is
            # taken, which equals the Frobenius norm and works for any number
            # of dimensions (``ord='fro'`` rejects tensors with ndim > 2).
            diff[name] = torch.linalg.norm(params - model0_params[name]).item()
    return diff

def concat_diff(diff, diffs = None):
    """Append the values of ``diff`` to the per-key histories in ``diffs``.

    Args:
        diff: Mapping from key to either a single value or an iterable of
            values.
        diffs: Accumulator mapping key to a list. When it is ``None`` or
            empty a new dict is created; otherwise it is updated in place.

    Returns:
        The accumulator, with iterable values extended onto and scalar values
        appended to the list stored under the same key.
    """
    if not diffs:
        diffs = {}

    for key, value in diff.items():
        # ``setdefault`` also covers keys that show up only after the
        # accumulator was created, which previously raised ``KeyError``.
        history = diffs.setdefault(key, [])
        if isinstance(value, Iterable):
            history.extend(value)
        else:
            history.append(value)

    return diffs

def normalize_vectorgroup(vector_group):
    """Scale every non-zero row of ``vector_group`` to unit L2 norm.

    Rows whose norm is not strictly positive are dropped.

    Returns:
        Tuple ``(unit_rows, n_kept)``: the normalized surviving rows and how
        many rows survived.
    """
    row_norms = np.linalg.norm(vector_group, axis=1)
    keep = row_norms > 0

    kept_rows = vector_group[keep]
    # Column vector of norms so the division broadcasts across each row.
    kept_norms = row_norms[keep].reshape(-1, 1)

    return kept_rows / kept_norms, kept_rows.shape[0]

def cosine_similarity_array(X):
    """Return the matrix of pairwise cosine similarities between rows of ``X``.

    Rows with zero norm are left as zero vectors, so their similarity to every
    row (including themselves) is 0 instead of NaN.
    """
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    # Dividing a zero row by 1 keeps it zero and avoids 0/0 -> NaN, which
    # would otherwise contaminate a full row and column of the result.
    safe_norms = np.where(norms > 0, norms, 1.0)
    unit_rows = X / safe_norms
    return np.dot(unit_rows, unit_rows.T)

def seperate_vectors_by_eigenvector(vector_group):
    """Reorder the Gram matrix of ``vector_group`` by its leading eigenvector.

    Zero-norm rows are dropped, the Gram matrix ``V @ V.T`` of the remaining
    rows is formed, and rows/columns are permuted by the ascending components
    of the eigenvector that belongs to the largest eigenvalue.

    Returns:
        Tuple ``(similarity_matrix, order_mask)``: the permuted Gram matrix
        and the permutation (indices into the non-zero rows). Both are empty
        when no non-zero row exists. The sign of an eigenvector is arbitrary,
        so the ordering may come out reversed.
    """
    nonzero_rows = np.linalg.norm(vector_group, axis=1) > 0
    vector_group = vector_group[nonzero_rows]
    similarity_matrix = np.dot(vector_group, vector_group.transpose())

    if similarity_matrix.shape[0] == 0:
        # Nothing to order; avoids ``argmax`` of an empty spectrum.
        return similarity_matrix, np.arange(0)

    # The Gram matrix is symmetric, so ``eigh`` applies: it guarantees real
    # eigenpairs and returns eigenvalues in ascending order, which makes the
    # last column the leading eigenvector.
    _, eigenvectors = np.linalg.eigh(similarity_matrix)
    order_mask = np.argsort(eigenvectors[:, -1])

    similarity_matrix = similarity_matrix[order_mask, :]
    similarity_matrix = similarity_matrix[:, order_mask]
    return similarity_matrix, order_mask

def plot_weight_heatmap_eigen(weight, save_path):
    """Save a borderless heatmap of row-to-row cosine similarity of ``weight``.

    The rows of ``weight`` are normalized (zero rows dropped), their Gram
    matrix is reordered by its leading eigenvector, and the result is drawn
    with a colour range of [-1, 1].

    Args:
        weight: 2-D array whose rows are the vectors to compare.
        save_path: Destination of the image file.

    Returns:
        The row permutation used for the plot, or ``None`` when ``weight``
        has a single column, in which case nothing is plotted.
    """
    if weight.shape[1] == 1:
        # Previously ``order`` was left unbound on this path, so the final
        # ``return`` raised ``UnboundLocalError``.
        return None

    weight_normalized, _ = normalize_vectorgroup(weight)
    similarity_matrix, order = seperate_vectors_by_eigenvector(weight_normalized)

    fig = plt.figure(frameon=False)  # no frame around the canvas
    try:
        ax = plt.Axes(fig, [0, 0, 1, 1])  # axes fill the canvas, no margins
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.pcolormesh(similarity_matrix, vmin=-1, vmax=1, cmap='YlGnBu')
        fig.savefig(save_path, dpi = 50, bbox_inches='tight', pad_inches=0, transparent=True)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return order


## condense condition
def get_condense_condition1(args, model, device, working_dir):
    '''
        Compute the first condense-condition count.

        Reads W_V, the two feed-forward weights of decoder layer 0 and the
        output projection (all transposed), forms
        ``W_a = unit(W_V^T[:, 0]) @ W_V^T`` and ``W_1^T @ W_2^T @ W_proj^T``,
        and returns how many entries of ``W_a[:, None] * (...)`` are positive.

        ``args``, ``device`` and ``working_dir`` are not used.
    '''
    first_layer = model.decoder.layers[0]
    with torch.no_grad():
        value_w = first_layer.dec_self_attn.W_V.weight.data.cpu().numpy().T
        ffn_in_w = first_layer.pos_ffn.fc[0].weight.data.cpu().numpy().T
        ffn_out_w = first_layer.pos_ffn.fc[2].weight.data.cpu().numpy().T
        proj_w = model.projection.weight.data.cpu().numpy().T

    # Unit vector along the first column of the transposed value matrix.
    direction = value_w[:, 0]
    direction = direction / np.linalg.norm(direction)
    projected = np.matmul(direction, value_w)

    ffn_product = np.matmul(ffn_in_w, ffn_out_w)
    ffn_to_output = np.matmul(ffn_product, proj_w)

    # Column vector so it broadcasts across the columns of ``ffn_to_output``.
    sign_agreement = (projected.reshape(-1, 1) * ffn_to_output) > 0
    return np.sum(sign_agreement).item()

def get_condense_condition2(args, model, device, working_dir):
    '''
        Compute the second condense-condition count.

        Forms ``W_b = W_1^T @ W_2^T`` from the feed-forward weights of decoder
        layer 0, scales column ``j`` of ``W_b`` by entry ``j`` of the first
        column of ``W_proj^T``, and returns how many columns of the scaled
        matrix have a positive dot product with its first column.

        ``args``, ``device`` and ``working_dir`` are not used.
    '''
    ffn = model.decoder.layers[0].pos_ffn
    with torch.no_grad():
        W_1_weight = ffn.fc[0].weight.data.cpu().numpy().T
        W_2_weight = ffn.fc[2].weight.data.cpu().numpy().T
        W_proj = model.projection.weight.data.cpu().numpy().T

    W_b = np.matmul(W_1_weight, W_2_weight)

    # Same as ``W_b @ np.diag(W_proj[:, 0])`` but without materialising the
    # dense diagonal matrix: broadcasting scales column j by W_proj[j, 0].
    tmp_matrix = W_b * W_proj[:, 0]

    tmp_vector = tmp_matrix[:, 0]
    return sum(
        1
        for i in range(tmp_matrix.shape[1])
        if np.dot(tmp_vector, tmp_matrix[:, i]) > 0
    )


def train_step(args, model, train_data_loader, optimizer, criterion, device, model0_params, working_dir, epoch, clip=1, scheduler=None):
    """Run one optimisation pass over ``train_data_loader``.

    For every batch the loss is computed on the last position only, gradients
    are clipped to norm ``clip`` and the optimizer (and, if given, the
    scheduler) is stepped. Before each update the distance of the parameters
    from ``model0_params`` is recorded. At selected steps a cosine-similarity
    heatmap is written under ``working_dir/condense_heatmap`` for every
    ``nn.Linear`` whose module name does not contain ``'proj'``.

    Args:
        args: Needs ``model`` and ``vocab_size`` (DNN variants) or ``seq_len``.
        model0_params: Reference parameters passed to ``calculate_F_norm``.
        epoch: Current epoch index; used in file names and the plot schedule.

    Returns:
        Tuple ``(mean_loss, diffs, loss_list)``: the sample-weighted mean loss,
        the per-parameter distance histories and the per-batch losses.
    """
    model.train()
    weighted_loss_sum = 0
    seen_samples = 0
    diffs = None
    loss_list = []
    is_dnn = args.model in ('DNN', 'DNN_averaged')

    for i, (dec_inputs, dec_outputs) in enumerate(train_data_loader):
        optimizer.zero_grad()
        dec_inputs = dec_inputs.to(device)
        dec_outputs = dec_outputs.to(device)
        outputs, _ = model(dec_inputs)

        # Parameter drift as of this step (taken before this batch's update).
        diffs = concat_diff(calculate_F_norm(model, model0_params), diffs)

        batch_size = dec_inputs.size(0)  # actual size of the current batch
        seen_samples += batch_size

        if is_dnn:
            predictions = outputs.view(batch_size, args.vocab_size)
        else:
            # Binary-classification variant: one score per position, keep the last.
            predictions = outputs.view(batch_size, args.seq_len, 1)[:, -1, :]
        loss = criterion(predictions, dec_outputs[:, -1].view(-1))

        # Plot when i*4 is a multiple of the loader length, plus every 10th
        # step of the first 100 steps of epoch 0.
        on_quarter_mark = (i * 4) % len(train_data_loader) == 0
        early_probe = epoch == 0 and i <= 100 and i % 10 == 0
        if on_quarter_mark or early_probe:
            for name, module in model.named_modules():
                if not isinstance(module, nn.Linear) or 'proj' in name:
                    continue
                weight = module.weight.data.cpu().numpy()
                cos_sim_matrix = cosine_similarity_array(weight)
                plot_weight_heatmap_eigen(cos_sim_matrix, working_dir + f'/condense_heatmap/epoch{epoch}_{i}_{name}.png')

        batch_loss = loss.item()
        loss_list.append(batch_loss)
        weighted_loss_sum += batch_loss * batch_size  # weight by batch size
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

    return weighted_loss_sum / seen_samples, diffs, loss_list


def test_step(args, model, test_data_loader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` on ``test_data_loader``.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()`` so no autograd graph is built during evaluation.
    The loss is computed on the last position only.

    Args:
        args: Needs ``model`` and ``vocab_size`` (DNN variants) or ``seq_len``.
        test_data_loader: Iterable of ``(dec_inputs, dec_outputs)`` batches.
        criterion: Callable ``(predictions, targets) -> scalar tensor``.
        device: Device the batches are moved to.

    Returns:
        Mean loss as a Python float.
    """
    model.eval()
    epoch_loss = 0
    total_samples = 0

    with torch.no_grad():
        for dec_inputs, dec_outputs in test_data_loader:
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)

            batch_size = dec_inputs.size(0)  # actual size of the current batch
            total_samples += batch_size

            if args.model == 'DNN' or args.model == 'DNN_averaged':
                loss = criterion(outputs.view(batch_size, args.vocab_size), dec_outputs[:,-1].view(-1))
            else:
                # Binary-classification variant: one score per position, keep the last.
                loss = criterion(outputs.view(batch_size, args.seq_len, 1)[:,-1,:], dec_outputs[:,-1].view(-1))

            epoch_loss += loss.item() * batch_size  # weight by batch size

    return epoch_loss / total_samples



# Batch prediction
def last_word_acc(args, model, data_loader):
    """Return the accuracy of the sign of the last-position score.

    A strictly positive score is read as class 1; a zero or negative score is
    read as class 0. The prediction is compared with the last column of the
    targets. Evaluation runs in eval mode under ``torch.no_grad()``.

    Args:
        args: Not used.
        model: Callable returning ``(outputs, _)`` for a batch of inputs.
        data_loader: Iterable of ``(dec_inputs, dec_outputs)`` batches.

    Returns:
        Fraction of correctly classified samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total_samples = 0

    with torch.no_grad():
        for dec_inputs, dec_outputs in data_loader:
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)

            total_samples += dec_inputs.size(0)  # actual size of the current batch

            # sign() gives -1 / 0 / +1; clamping at 0 maps -1 to 0 so the
            # prediction lives in {0, 1} like the labels.
            predicted = torch.sign(outputs[:, -1].squeeze()).clamp(min=0)
            correct += (predicted == dec_outputs[:, -1]).sum().item()

    return correct / total_samples


def get_accuracy(args, model, data_loader_group, train_percent, test_percent, my_logger):
    '''
        Compute the accuracy of every dataset listed in ``args.data_name``.

        Each dataset's accuracy is logged and folded into a weighted sum:
        datasets with ``args.data_train[i] == 1`` contribute
        ``acc * data_percent[i] / train_percent`` to the train accuracy, all
        others contribute ``acc * data_percent[i] / test_percent`` to the test
        accuracy.

        Returns train_acc, test_acc, acc_list (per-dataset accuracies in the
        order of ``args.data_name``).
    '''
    acc_list = []
    train_acc = 0
    test_acc = 0

    for idx, data_name in enumerate(args.data_name):
        tmp_acc = last_word_acc(args, model, data_loader_group[data_name])
        acc_list.append(tmp_acc)

        belongs_to_train = args.data_train[idx] == 1
        if belongs_to_train:
            train_acc += tmp_acc * args.data_percent[idx] / train_percent
        else:
            test_acc += tmp_acc * args.data_percent[idx] / test_percent

        my_logger.info(f'data type: {data_name} \t Acc: {tmp_acc}')

    return train_acc, test_acc, acc_list



def _get_loss_of_each_data(args, model, data_loader_group, criterion, device):
    '''
        Compute the loss of every dataset with ``data_train == 0``.

        Training datasets are not evaluated here; their slot in the returned
        list is 0.

        Returns (loss_list, test_loss): the per-dataset losses in the order of
        ``args.data_name`` and the mean loss over all evaluated samples.
        ``test_loss`` is 0.0 when no dataset has ``data_train == 0``.
    '''
    weighted_loss = 0
    total_samples = 0
    loss_list = []
    for i, data_name in enumerate(args.data_name):
        if args.data_train[i] != 0:
            loss_list.append(0)
            continue

        data_loader = data_loader_group[data_name]
        tmp_loss = test_step(args, model, data_loader, criterion, device)
        loss_list.append(tmp_loss)

        n_samples = len(data_loader.dataset)
        total_samples += n_samples
        weighted_loss += tmp_loss * n_samples

    if total_samples == 0:
        # No held-out data at all: avoid dividing by zero.
        return loss_list, 0.0

    return loss_list, weighted_loss / total_samples






def _plot_training_diagnostics(working_dir, Diffs, train_loss_list_along_step,
                               condense1_condition_list, condense2_condition_list):
    """Rewrite the three diagnostic figures stored directly in ``working_dir``."""
    format_settings(fs = 10)

    # Distance of each tracked parameter from its initial value, per step.
    fig, ax = plt.subplots()
    for name in Diffs:
        ax.semilogy(Diffs[name], label = name)
    ax.legend(fontsize = 10)
    ax.set_xlabel('Step', fontsize = 10)
    ax.set_ylabel('Diff', fontsize = 10)
    fig.savefig(f'{working_dir}/diff.png')
    plt.close(fig)

    # Training loss per optimisation step. The curve has no label, so no
    # legend is requested (calling legend() here only produced a warning).
    fig, ax = plt.subplots()
    ax.plot(train_loss_list_along_step)
    ax.set_xlabel('Step', fontsize = 10)
    ax.set_ylabel('Loss', fontsize = 10)
    fig.savefig(f'{working_dir}/train_loss_list_along_step.png')
    plt.close(fig)

    # Condense-condition counts per epoch.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(condense1_condition_list, label = 'condense1_condition')
    ax.plot(condense2_condition_list, label = 'condense2_condition')
    ax.legend(fontsize = 10)
    ax.set_xlabel('Epoch', fontsize = 10)
    ax.set_ylabel('Condition', fontsize = 10)
    fig.savefig(f'{working_dir}/condense_condition.png')
    plt.close(fig)


def train(args, datas, **kwargs):
    '''
    Train a model and write logs, checkpoints, histories and plots to
    ``args.working_dir``.

    Required:
        args: hyper-parameter namespace (mutated: ``num_batches`` is set and
              ``data_percent`` is normalized to sum to 1)
        datas: dict holding every dataset, keyed by dataset name
    Extra keyword arguments are forwarded to ``get_model`` / ``get_optimizer``
    and stored in ``config.json``.
    '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training set
    train_data_loader = get_train_data(args, datas)
    args.num_batches = len(train_data_loader)

    # One data loader per dataset
    data_loader_group = get_data_loader_group(args, datas)

    my_logger = Log(f'{args.working_dir}/train_log.log')

    # Model and parameter count
    model = get_model(args, device, **kwargs)
    if args.dtype == 'float64':
        model = model.double()

    my_logger.info(f'Model: {model}')

    # Snapshot of every parameter with more than one dimension; used later to
    # measure how far training moves them.
    model0_params = {}
    with torch.no_grad():
        for name, params in model.named_parameters():
            if params.ndim > 1:
                model0_params[name] = params.clone()

    my_logger.info(f'Total parameters: {sum(p.numel() for p in model.parameters())}')

    # NOTE(review): train_step/test_step feed a single score per sample for
    # non-DNN models, which looks like it is meant for the ExponentialLoss
    # below rather than CrossEntropyLoss -- confirm which criterion is intended.
    criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)
    # criterion = ExponentialLoss(ignore_index=None).to(device)

    optimizer, scheduler = get_optimizer(model, args, **kwargs)

    # Normalize data_percent so it sums to 1
    percent_list = np.array(args.data_percent)
    percent_list = percent_list / np.sum(percent_list)
    args.data_percent = percent_list.tolist()

    # Save the configuration, including kwargs and the size of every dataset
    save_args = dict(vars(args))
    save_args.update(kwargs)
    for data_name in args.data_name:
        save_args[f'data_size_{data_name}'] = len(datas[data_name])
    save_to_json_noindent(save_args, f'{args.working_dir}/config.json')

    # Make sure every output sub-directory written to below exists.
    for sub_dir in ('data', 'src', 'model', 'loss', 'condense_heatmap'):
        os.makedirs(f'{args.working_dir}/{sub_dir}', exist_ok=True)

    # Save the training data
    np.savez(f'{args.working_dir}/data/datas.npz', **datas)

    # Save a copy of the source code
    for file in ['main.py', 'data.py', 'train.py', 'test.py', 'script_LTP.py']:
        shutil.copy(file, f'{args.working_dir}/src/{file}')
    for dir_name in ['utils', 'model', 'data_generator']:
        shutil.copytree(dir_name, f'{args.working_dir}/src/{dir_name}', dirs_exist_ok=True)

    train_loss_his = []        # training-set loss per epoch
    test_loss_his = []         # total loss over the data_train=0 datasets
    group_loss_his = []        # loss per dataset (training datasets are recorded as 0)

    acc_epoch_his = []         # epochs at which accuracy was evaluated
    train_acc_his = []         # weighted accuracy over the data_train=1 datasets
    test_acc_his = []          # weighted accuracy over the data_train=0 datasets
    group_acc_his = []         # accuracy per dataset
    train_loss_list_along_step = []
    lr_his = []
    condense1_condition_list = []
    condense2_condition_list = []

    # Total share of training and of test data
    train_percent, test_percent = 0, 0
    for i in range(len(args.data_name)):
        if args.data_train[i] == 1:
            train_percent += args.data_percent[i]
        else:
            test_percent += args.data_percent[i]

    print('training...')
    torch.save(model.state_dict(), f'{args.working_dir}/model/model_ini.pt')
    Diffs = None
    last_epoch = args.n_epoch - 1
    for epoch in range(args.n_epoch):
        # Evaluate and log accuracy
        if epoch % args.print_acc_epoch == 0 or epoch == last_epoch:
            train_acc, test_acc, acc_list = get_accuracy(args, model, data_loader_group, train_percent, test_percent, my_logger)

            acc_epoch_his.append(epoch)
            train_acc_his.append(train_acc)
            test_acc_his.append(test_acc)
            group_acc_his.append(acc_list)

        # Record the learning rate in effect at the start of the epoch
        lr_his.append(optimizer.param_groups[0]['lr'])

        # Train for one epoch and evaluate the losses
        train_loss, diffs, tmp_train_loss_list = train_step(args, model, train_data_loader, optimizer, criterion, device, model0_params, args.working_dir, epoch, args.clip, scheduler=scheduler)
        tmp_loss_list, test_loss = _get_loss_of_each_data(args, model, data_loader_group, criterion, device)

        train_loss_his.append(train_loss)
        group_loss_his.append(tmp_loss_list)
        test_loss_his.append(test_loss)
        Diffs = concat_diff(diffs, Diffs)
        train_loss_list_along_step.extend(tmp_train_loss_list)

        if epoch % args.print_loss_epoch == 0:
            my_logger.info(f'Epoch: {epoch:<5}  Train Loss: {train_loss:.4e}  Test Loss: {test_loss:.4e}')

        # Save a checkpoint
        if (epoch % args.save_model_epoch == 0) or epoch == last_epoch:
            torch.save(model.state_dict(), f'{args.working_dir}/model/model_{epoch}.pt')

        # Condense conditions
        condense_cond_num1 = get_condense_condition1(args, model, device, args.working_dir)
        condense_cond_num2 = get_condense_condition2(args, model, device, args.working_dir)

        condense1_condition_list.append(condense_cond_num1)
        condense2_condition_list.append(condense_cond_num2)

        # Save the histories and refresh the summary plots
        if ((epoch % args.plot_loss_acc_epoch == 0) and (epoch != 0)) or (epoch == last_epoch):
            histories = {
                'train_loss_his': train_loss_his,
                'test_loss_his': test_loss_his,
                'group_loss_his': group_loss_his,
                'acc_epoch_his': acc_epoch_his,
                'train_acc_his': train_acc_his,
                'test_acc_his': test_acc_his,
                'group_acc_his': group_acc_his,
                'train_loss_list_along_step': train_loss_list_along_step,
                'lr_his': lr_his,
                'condense1_condition_list': condense1_condition_list,
                'condense2_condition_list': condense2_condition_list,
            }
            for history_name, values in histories.items():
                np.save(f'{args.working_dir}/loss/{history_name}.npy', np.array(values))

            # Learning rate
            plot_lr(args.working_dir)

            # Loss
            plot_loss(args.working_dir)

            # Accuracy
            plot_acc(args.working_dir)

            # Per-dataset loss and accuracy
            if np.sum(args.data_show) != 0:
                plot_loss_of_each_data(args.working_dir)
                plot_acc_of_each_data(args.working_dir)

        # These diagnostic figures are rewritten after every epoch.
        _plot_training_diagnostics(args.working_dir, Diffs, train_loss_list_along_step,
                                   condense1_condition_list, condense2_condition_list)

    # The original line was a bare string expression and printed nothing.
    print('training finished!')



