import torch
import torch.nn as nn
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
import numpy as np
from einops import rearrange

from vit_pytorch import ViT

import os
import sys
import time
import math

import optim
from model_utils import get_model
from config import ModelConfig

# Seed the global torch RNG used by the module-level code below.
torch.manual_seed(2024)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.CrossEntropyLoss()

# Per-sample preprocessing: convert to a tensor, then standardise with
# mean 0.1307 and std 0.3081.
_normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([transforms.ToTensor(), _normalize])

# Both MNIST splits share the same storage root, download flag and transform.
_mnist_kwargs = dict(root='./data', download=True, transform=transform)
trainset = datasets.MNIST(train=True, **_mnist_kwargs)
testset = datasets.MNIST(train=False, **_mnist_kwargs)

# Independent copies of each split's raw `.data` tensor.
mnist_rawdata_train = trainset.data.clone()
mnist_rawdata_test = testset.data.clone()

def filter_dataset(dataset, class1 = 3, class2 = 7):
    """Extract two classes from a dataset and relabel them as +1 / -1.

    Args:
        dataset: object exposing ``data`` (tensor indexed by sample along
            dim 0) and ``targets`` (1-D tensor of class ids).
        class1: class id whose samples receive label ``+1``.
        class2: class id whose samples receive label ``-1``.

    Returns:
        A ``(filtered_data, filtered_targets)`` tuple. All ``class1`` samples
        come first, followed by all ``class2`` samples. The data is rescaled
        as ``(x / 255 - 0.1307) / 0.3081``; the targets keep the dtype of
        ``dataset.targets``.
    """
    indices1 = torch.where(dataset.targets == class1)[0]
    indices2 = torch.where(dataset.targets == class2)[0]
    indices = torch.cat([indices1, indices2], dim=0)

    # Same mean/std constants as the module-level `transform`, applied
    # directly to the raw pixel tensor.
    filtered_data = (dataset.data[indices] / 255. - 0.1307) / 0.3081

    # Labels are built from group membership instead of two successive
    # value-based remaps: remapping class1 -> 1 and then class2 -> -1 would
    # also flip the freshly written +1 labels whenever class2 == 1 (and
    # likewise mis-handle class1 == -1).
    filtered_targets = torch.cat([
        dataset.targets.new_full((indices1.numel(),), 1),
        dataset.targets.new_full((indices2.numel(),), -1),
    ], dim=0)

    return filtered_data, filtered_targets

# Reduce both MNIST splits to the default two-class (+1 / -1) problem.
trainset_filtered_x, trainset_filtered_y = filter_dataset(trainset)
testset_filtered_x, testset_filtered_y = filter_dataset(testset)

# Number of samples kept in each split (size of dim 0).
train_num, test_num = len(trainset_filtered_x), len(testset_filtered_x)
print(train_num, test_num)




def binary_class_loss(outputs, targets):
    """Return the mean logistic loss ``log(1 + exp(-outputs * targets))``.

    Args:
        outputs: real-valued model scores (tensor).
        targets: labels in ``{+1, -1}``; combined with ``outputs`` by
            element-wise multiplication, so the two must broadcast.

    Returns:
        0-dim tensor holding the mean loss over all elements.
    """
    # softplus(z) == log(1 + exp(z)), but it is evaluated without forming
    # exp(z) for large z. The naive form overflows to inf in float32 once a
    # sample is misclassified with a large margin, which then poisons the
    # gradients.
    return torch.nn.functional.softplus(-outputs * targets).mean()

def binary_class_correct(outputs, targets):
    """Count predictions whose sign-based label equals the target.

    A strictly positive score is read as ``+1``, anything else as ``-1``.
    The count is returned as a Python float.
    """
    is_positive = outputs > 0
    # Map {False, True} onto {-1.0, +1.0}.
    signed_prediction = 2 * is_positive.float() - 1
    hits = signed_prediction == targets
    return hits.float().sum().item()

# Training
def train(net, optimizer, trainloader, epoch, verbose=False):
    """Run one optimisation pass over ``trainloader``.

    Args:
        net: model whose forward pass returns a mapping with the keys
            ``"out"``, ``"attn_weights"``, ``"attn_weights_p"`` and
            ``"attn_weights_n"``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable of ``(inputs, targets)`` batches.
        epoch: epoch index, only used for the verbose banner.
        verbose: print the epoch banner when true.

    Returns:
        ``(train_loss, train_acc, packed)``: the sample-weighted mean loss,
        the accuracy in percent, and a tuple
        ``(attn_weights, attn_weights_p, attn_weights_n, out)`` taken from
        the forward pass of the final batch.
    """
    if verbose:
        print('Epoch: %d' % epoch)
    net.train()

    loss_sum, n_correct, n_seen = 0, 0, 0
    for _, (batch_x, batch_y) in enumerate(trainloader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        outputs = net(batch_x)
        loss = binary_class_loss(outputs["out"], batch_y)

        optimizer.zero_grad()
        loss.backward()
        # Accuracy is measured on the scores from before the parameter update.
        n_correct += binary_class_correct(outputs["out"], batch_y)
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch value is a
        # true per-sample mean even when the last batch is smaller.
        batch_len = batch_x.shape[0]
        loss_sum += loss.item() * batch_len
        n_seen += batch_len

    # Only the last batch's outputs survive the loop.
    packed = (
        outputs["attn_weights"],  # (batch, seqlen, seqlen) per the original note
        outputs["attn_weights_p"],
        outputs["attn_weights_n"],
        outputs["out"],
    )

    return loss_sum / n_seen, n_correct / n_seen * 100, packed


def test(net, testloader, epoch):
    """Evaluate ``net`` on ``testloader`` without tracking gradients.

    Args:
        net: model whose forward pass returns a mapping with an ``"out"`` entry.
        testloader: iterable of ``(inputs, targets)`` batches.
        epoch: accepted for symmetry with ``train``; not used.

    Returns:
        ``(test_loss, test_acc)``: the sample-weighted mean loss and the
        accuracy in percent.
    """
    net.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for _, (batch_x, batch_y) in enumerate(testloader):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scores = net(batch_x)["out"]

            batch_len = batch_x.shape[0]
            # Batch-mean loss re-weighted by batch size.
            loss_sum += binary_class_loss(scores, batch_y).item() * batch_len
            n_correct += binary_class_correct(scores, batch_y)
            n_seen += batch_len

    return loss_sum / n_seen, n_correct / n_seen * 100


def create_model():
    """Build a depth-1, 2-head ViT for 28x28 single-channel images on ``device``."""
    vit_kwargs = {
        "image_size": 28,
        "patch_size": 7,
        "num_classes": 10,
        "dim": 128,
        "depth": 1,
        "heads": 2,
        "mlp_dim": 256,
        "dropout": 0.0,
        "emb_dropout": 0.0,
        "channels": 1,
    }
    model = ViT(**vit_kwargs)
    return model.to(device)

def one_run(model, optimizer, trainloader, testloader, n_epoch, save_path=None, verbose=False):
    """Train ``model`` for ``n_epoch`` epochs, evaluating after every epoch.

    Args:
        model: network passed through to ``train`` / ``test``.
        optimizer: optimizer passed through to ``train``.
        trainloader: training batches.
        testloader: evaluation batches.
        n_epoch: number of epochs to run.
        save_path: accepted but not used by this function.
        verbose: print a metrics line after every epoch when true.

    Returns:
        ``(train_loss_values, train_acc_values, test_acc_values,
        test_loss_values, state_dict)``. The first four are per-epoch lists;
        ``state_dict`` holds the same curves as tensors plus the
        final-epoch scalars.
    """
    history = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": [],
    }

    # Baseline evaluation before any parameter update (epoch index -1).
    test_loss, test_acc = test(model, testloader, -1)
    print(f"Before training: test_loss={test_loss:0.3e}| test_acc={test_acc}")

    for ep in range(n_epoch):
        # NOTE: `packed` (last-batch attention weights and outputs) is returned
        # by train() but not consumed; per-epoch feature-learning /
        # attention-dynamics logging is currently disabled.
        train_loss, train_acc, packed = train(model, optimizer, trainloader, ep, verbose=False)
        test_loss, test_acc = test(model, testloader, ep)
        if verbose:
            print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e} | train_correct={train_acc} | test_loss={test_loss:0.5e} | test_correct={test_acc}')

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)
        history["test_loss"].append(test_loss)

    train_loss_values = history["train_loss"]
    train_acc_values = history["train_acc"]
    test_acc_values = history["test_acc"]
    test_loss_values = history["test_loss"]

    # Curves are stored as tensors, final-epoch values as plain scalars.
    state_dict = {
        "train_loss_values": torch.as_tensor(train_loss_values),
        "test_loss_values": torch.as_tensor(test_loss_values),
        "train_acc_values": torch.as_tensor(train_acc_values),
        "test_acc_values": torch.as_tensor(test_acc_values),
        "final_test_loss": test_loss_values[-1],
        "final_train_loss": train_loss_values[-1],
        "final_test_acc": test_acc_values[-1],
        "final_train_acc": train_acc_values[-1],
    }
    return (
        train_loss_values,
        train_acc_values,
        test_acc_values,
        test_loss_values,
        state_dict,
    )
    

# ---- experiment hyper-parameters -------------------------------------------
batchsize = 128        # mini-batch size of the training DataLoader
n_epoch = 200          # epochs per run
base_datanum = 128     # training-subset size at t == 0
dk = 10                # forwarded to ModelConfig(dk=...)
dv = 10                # forwarded to ModelConfig(dv=...)
scale_type_dft = 2     # selects the data-generation branch in the sweep

optim_name = "gd" # variable
# Optimizers swept by the main loop; every entry must be a key of the three
# tables below.
optim_name_list = ['adam', 'adam5']

# Optimizer name -> optimizer class.
optim_cls_dict = {
    "gd": torch.optim.SGD,
    "gdm": torch.optim.SGD,
    "adam": torch.optim.Adam,
    "sign": torch.optim.Adam,
    "mysign": optim.signGD,
    "adam0": torch.optim.Adam,
    "adam5": torch.optim.Adam,
    "normgd": optim.normalizedGD,
}

# Optimizer name -> constructor keyword arguments. The "lr" entry is
# overwritten by the sweep with each value from `optim_lr_dict`.
optim_args_dict = {
    "gd": {"lr": 1e-2},
    "gdm": {"lr": 0.05, "momentum": 0.9},
    "adam": {"lr": 1e-4, "eps": 1e-15},
    "sign": {"lr": 1e-3, "betas": (0.0, 0.0)},
    "mysign": {"lr": 1e-4},
    "adam0": {"lr": 1e-3, "betas": (0.0, 0.999), "eps": 1e-15},
    "adam5": {"lr": 1e-3, "betas": (0.5, 0.999), "eps": 1e-15},
    "normgd": {"lr": 1e-3},
}

# Optimizer name -> learning rates to sweep.
optim_lr_dict = {
    "gd": [1e-1, 5e-1],
    "gdm": [1e-1],
    "adam": [1e-3, 5e-4],
    "mysign": [1e-3],
    "adam0": [1e-3, 5e-4],
    "adam5": [1e-3, 5e-4],
    "normgd": [1e-1, 1e-2],
    "sign": [1e-3, 1e-2],
}

# Main sweep. For every optimizer and learning rate, train one model per
# (noise index s, data-size index t, seed) and store the seed-averaged final
# metrics in 10x10 grids indexed by [s, t].
seed_list = [0, 1, 2]

# torch.save does not create parent directories; make sure they exist up
# front so a finished run is never lost to a missing folder.
os.makedirs("./results_mnist/runs", exist_ok=True)

for optim_name in optim_name_list:
    optim_cls = optim_cls_dict[optim_name]
    optim_args = optim_args_dict[optim_name].copy()
    optim_lr_list = optim_lr_dict[optim_name]
    for lr in optim_lr_list:
        # optim setting
        optim_args["lr"] = lr
        print(optim_cls, lr)
        filename = f"./results_mnist/{optim_name}_{lr}_dk{dk}dv{dv}_base{base_datanum}scaledft{scale_type_dft}_epoch{n_epoch}"

        # Grids of seed-averaged final metrics. Cells that are not recomputed
        # in this invocation keep whatever a previous invocation stored.
        final_train_loss = torch.zeros((10, 10))
        final_test_loss = torch.zeros((10, 10))
        final_train_acc = torch.zeros((10, 10))
        final_test_acc = torch.zeros((10, 10))

        if os.path.exists(filename):
            state_dict = torch.load(filename)
            print(f"load {filename} successfully")
            final_train_loss = state_dict["final_train_loss"]
            final_test_loss = state_dict["final_test_loss"]
            final_train_acc = state_dict["final_train_acc"]
            final_test_acc = state_dict["final_test_acc"]

        for scale_type in [scale_type_dft]:
            for s in range(10):
                for t in range(10):
                    # Only the smallest training-set size is currently swept.
                    if t != 0:
                        continue

                    # (train_loss, test_loss, train_acc, test_acc) of the last
                    # epoch, one tuple per seed.
                    seed_finals = []
                    for seed in seed_list:
                        torch.manual_seed(seed)

                        # data setting
                        if scale_type != 0:
                            # Image weight sigma takes the values 0.1, 0.2, ..., 1.0.
                            # The noise scale is 1 (scale_type == 1) or
                            # 1 - sigma (any other non-zero scale_type).
                            sigma = 0.1 * (s+1)
                            noise_scale = 1.0 if scale_type == 1 else (1-sigma)
                        else:
                            # NOTE(review): both the noise and the image term are
                            # multiplied by sigma = s here, so s == 0 yields
                            # all-zero inputs; confirm this is intended.
                            sigma = s
                            noise_scale = sigma

                        # RNG order matters for reproducibility: train noise
                        # first, then test noise.
                        A = torch.randn(train_num,28,28) * noise_scale
                        A_1 = torch.randn(test_num,28,28) * noise_scale
                        # Remove the noise from the central 14x14 window.
                        A[:, 7:21, 7:21] = 0
                        A_1[:, 7:21, 7:21] = 0
                        train_noisy_filtered_x = A   + sigma*trainset_filtered_x
                        test_noisy_filtered_x  = A_1 + sigma*testset_filtered_x

                        print(f"sigma = {sigma}; t = {t}; scale_type = {scale_type}")

                        # Random training subset of size base_datanum + t*500.
                        shuffled_indices = torch.randperm(train_num)
                        train_s_n_f_x = train_noisy_filtered_x.index_select(0, shuffled_indices) # shuffled, noisy, filtered
                        train_s_f_y = trainset_filtered_y.index_select(0, shuffled_indices) # shuffled, filtered
                        train_final_x = train_s_n_f_x[:base_datanum+t*500]
                        train_final_y = train_s_f_y[:base_datanum+t*500]
                        print(train_final_x.shape, train_final_y.shape)

                        # Random test subset of 500 samples.
                        shuffled_indices = torch.randperm(test_num)
                        test_s_n_f_x = test_noisy_filtered_x.index_select(0, shuffled_indices) # shuffled, noisy, filtered
                        test_s_f_y = testset_filtered_y.index_select(0, shuffled_indices) # shuffled, filtered
                        test_final_x = test_s_n_f_x[:500]
                        test_final_y = test_s_f_y[:500]

                        trainset_nf = torch.utils.data.TensorDataset(train_final_x, train_final_y)
                        testset_nf  = torch.utils.data.TensorDataset(test_final_x,  test_final_y)

                        # create data
                        trainloader = torch.utils.data.DataLoader(
                            trainset_nf, batch_size=batchsize, shuffle=True, num_workers=4)
                        testloader = torch.utils.data.DataLoader(
                            testset_nf, batch_size=test_num, shuffle=False, num_workers=4)

                        # create model
                        n_dim = 49
                        act_q = 3
                        act_linear = True
                        model_cls = "AttnBinary" # variable
                        attn_type = "softmax" # variable
                        fixed_qk = False # variable
                        model_config_dict = dict(
                            model_cls=model_cls,
                            hidden_size=n_dim,
                            act_q=act_q,
                            act_linear=act_linear,
                            dk=int(dk),
                            dv=int(dv),
                            num_attention_heads=1,
                            attn_type=attn_type,
                            fixed_qk=fixed_qk,
                            fixed_v=False,
                            attn_scaling_type="1",
                            patch_size=7,
                            seed_model=seed,
                        )
                        model_config = ModelConfig(**model_config_dict)
                        model = get_model(model_config)
                        model.to(device)

                        # create optimizer
                        optimizer = optim_cls(model.parameters(), **optim_args)

                        (
                            train_loss_values,
                            train_acc_values,
                            test_acc_values,
                            test_loss_values,
                            state_dict,
                        ) = one_run(model, optimizer, trainloader, testloader, n_epoch, verbose=True)

                        # NOTE: the per-run filename does not encode t; enabling
                        # t != 0 would overwrite runs that share sigma and seed.
                        filename_run = f"./results_mnist/runs/{optim_name}_{lr}_dk{dk}dv{dv}_base{base_datanum}scaledft{scale_type_dft}_epoch{n_epoch}_sigma{sigma}seed{seed}"
                        torch.save(state_dict, filename_run)

                        seed_finals.append((
                            train_loss_values[-1],
                            test_loss_values[-1],
                            train_acc_values[-1],
                            test_acc_values[-1],
                        ))

                    # Overwrite (never add to) the cell with the fresh seed
                    # average. Accumulating onto values loaded from disk and
                    # dividing the whole grid again on save would corrupt
                    # previously stored averages on every re-run.
                    n_runs = len(seed_finals)
                    final_train_loss[s, t] = sum(r[0] for r in seed_finals) / n_runs
                    final_test_loss[s, t] = sum(r[1] for r in seed_finals) / n_runs
                    final_train_acc[s, t] = sum(r[2] for r in seed_finals) / n_runs
                    final_test_acc[s, t] = sum(r[3] for r in seed_finals) / n_runs

            # The grids already hold averages, so they are stored as-is.
            state_dict_loss = {
                "final_train_loss": final_train_loss,
                "final_test_loss": final_test_loss,
                "final_train_acc": final_train_acc,
                "final_test_acc": final_test_acc,
            }
            torch.save(state_dict_loss, f"{filename}")
            print(f"save {filename} successfully")

# print(f"final_train_loss: {final_train_loss}")
# print(f"final_train_acc: {final_train_acc}")
# print(f"final_test_loss: {final_test_loss}")
# print(f"final_test_acc: {final_test_acc}")