import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import random
from torch import nn
import torch.nn.functional  as F
import math

import os
import copy

from data_utils import prepare_data
from model_utils import (
    get_model,
)
from config import *
import optim


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def binary_class_loss(outputs, targets):
    """Mean logistic loss ``log(1 + exp(-y * f))`` for +/-1 labels.

    Args:
        outputs: real-valued model scores ``f``.
        targets: labels ``y`` (expected to be -1 or +1), same shape as
            ``outputs`` or broadcastable with it.

    Returns:
        0-dim tensor holding the mean loss over all elements.
    """
    # softplus(z) == log(1 + exp(z)), evaluated stably. The naive form
    # overflows to inf once -y*f exceeds ~88 in float32, which turns the loss
    # into inf and its gradient into nan.
    return F.softplus(-outputs * targets).mean()

def binary_class_correct(outputs, targets):
    """Count the sign predictions that agree with +/-1 targets.

    A score strictly greater than zero is read as class +1. Everything else,
    including exactly 0, is read as class -1.

    Returns:
        The number of matching elements as a Python float.
    """
    is_positive = (outputs > 0).float()
    predicted = 2 * is_positive - 1
    hits = predicted.eq(targets).float()
    return hits.sum().item()

# Training
def train(net, optimizer, trainloader, epoch, verbose=False):
    """Run one training epoch (one optimizer step per batch) and report metrics.

    Args:
        net: model whose forward returns a dict with keys ``"out"``,
            ``"attn_weights"``, ``"attn_weights_p"`` and ``"attn_weights_n"``.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        trainloader: iterable of ``(x, y)`` batches with labels in {-1, +1}.
        epoch: epoch index, used only for the verbose banner.
        verbose: print the epoch banner when True.

    Returns:
        ``(train_loss, train_acc, packed)``.
        ``train_loss`` and ``train_acc`` are the sample-weighted mean loss and
        accuracy over every batch of the epoch. Each batch is measured before
        its own optimizer step.
        ``packed`` is ``(attn_weights, attn_weights_p, attn_weights_n, out)``
        taken from the LAST batch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    if verbose:
        print('Epoch: %d' % epoch)
    net.train()

    loss_sum = 0.0
    correct_sum = 0.0
    total_samples = 0
    packed = None
    for train_x, train_y in trainloader:
        train_x, train_y = train_x.to(device), train_y.to(device)
        batch_size = train_x.shape[0]
        outputs = net(train_x)
        out = outputs["out"]
        loss = binary_class_loss(out, train_y)
        optimizer.zero_grad()
        loss.backward()
        correct_num = binary_class_correct(out, train_y)
        optimizer.step()

        # Weight by batch size so the epoch metrics are true per-sample means
        # even when the loader yields several (possibly uneven) batches.
        loss_sum += loss.item() * batch_size
        correct_sum += correct_num
        total_samples += batch_size

        # one_run() expects the whole training set in a single batch, so the
        # last batch's attention weights cover every training sample there.
        packed = (
            outputs["attn_weights"],
            outputs["attn_weights_p"],
            outputs["attn_weights_n"],
            out,
        )

    if total_samples == 0:
        raise ValueError("trainloader yielded no batches")

    train_loss = loss_sum / total_samples
    train_acc = correct_sum / total_samples
    return train_loss, train_acc, packed


def test(net, testloader, epoch):
    """Evaluate ``net`` on every batch of ``testloader`` with gradients disabled.

    Leaves ``net`` in eval mode.

    Args:
        net: model whose forward returns a dict containing ``"out"``.
        testloader: iterable of ``(x, y)`` batches with labels in {-1, +1}.
        epoch: unused; kept so the signature mirrors ``train``.

    Returns:
        ``(test_loss, test_acc)``: the sample-weighted mean logistic loss and
        accuracy over the whole loader.

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    loss_sum = 0.0
    correct_sum = 0.0
    total_samples = 0
    with torch.no_grad():
        for test_x, test_y in testloader:
            test_x, test_y = test_x.to(device), test_y.to(device)
            batch_size = test_x.shape[0]
            out = net(test_x)["out"]
            # Weight by batch size so multi-batch loaders report per-sample
            # means instead of only the last batch's metrics.
            loss_sum += binary_class_loss(out, test_y).item() * batch_size
            correct_sum += binary_class_correct(out, test_y)
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("testloader yielded no batches")

    test_loss = loss_sum / total_samples
    test_acc = correct_sum / total_samples
    return test_loss, test_acc


def _to_numpy(tensor):
    """Detach ``tensor``, move it to the CPU and return it as a NumPy array."""
    return tensor.detach().cpu().numpy()


def one_run(model, optimizer, trainloader, testloader, n_epoch, save_path=None, verbose=False):
    """Train ``model`` for ``n_epoch`` epochs while recording weight/data projections.

    Before each epoch's update, the model is put in eval mode with gradients
    disabled. The attention matrices ``Wv`` (split into its first and second
    halves of rows, the "p"/"n" parts), ``Wq`` and ``Wk`` are then projected
    onto:

    * the "feature" token: token 0 of training sample 0;
    * the "noise" tokens: positions ``n_context // 2`` up to ``n_context`` of
      every training sample.

    After the update, the last training batch's attention weights and outputs
    are stored as well.

    Relies on the module-level globals ``n_context`` and ``n_train`` (set by
    ``main_worker`` or ``config``). The per-sample arrays have ``n_train``
    columns, so each epoch must present the whole training set as one batch.

    Args:
        model: model exposing ``model.attn`` with ``Wq``, ``Wk``, ``Wv``,
            ``dk`` and ``dv``.
        optimizer: optimizer passed through to ``train``.
        trainloader: training batches ``(x, y)``; ``x`` is indexed as
            ``x[sample, token]`` to get a vector.
        testloader: evaluation batches passed to ``test``.
        n_epoch: number of epochs; must be at least 1.
        save_path: if given, a directory; the result dict is saved to
            ``<save_path>/dynamics.pt``.
        verbose: print per-epoch metrics.

    Returns:
        dict with per-epoch loss/accuracy tensors, the final metrics, and the
        NumPy arrays of recorded dynamics (last axis is the epoch).

    Raises:
        ValueError: if ``n_epoch < 1`` (there would be no final metrics).
    """
    if n_epoch < 1:
        raise ValueError(f"n_epoch must be at least 1, got {n_epoch}")

    train_loss_values = []
    test_loss_values = []
    train_acc_values = []
    test_acc_values = []

    attn = model.attn
    n_hidden = attn.dv // 2
    dk = attn.dk
    feature_learning_p = np.zeros((n_hidden, n_epoch))
    feature_learning_n = np.zeros((n_hidden, n_epoch))
    feature_wq_inner_product = np.zeros((dk, n_epoch))
    feature_wk_inner_product = np.zeros((dk, n_epoch))

    # Tokens at positions noise_start .. n_context-1 are probed as "noise".
    noise_start = n_context // 2
    if n_context == 2:
        noise_memorization_p = np.zeros((n_hidden, n_train, n_epoch))
        noise_memorization_n = np.zeros((n_hidden, n_train, n_epoch))
        noise_wq_inner_product = np.zeros((dk, n_train, n_epoch))
        noise_wk_inner_product = np.zeros((dk, n_train, n_epoch))
    else:
        # Count the probed noise positions exactly. n_context // 2 alone is
        # one short when n_context is odd, which made the assignment below fail.
        n_noise = n_context - noise_start
        noise_memorization_p = np.zeros((n_noise, n_hidden, n_train, n_epoch))
        noise_memorization_n = np.zeros((n_noise, n_hidden, n_train, n_epoch))
        noise_wq_inner_product = np.zeros((n_noise, dk, n_train, n_epoch))
        noise_wk_inner_product = np.zeros((n_noise, dk, n_train, n_epoch))

    attn_weights_dynamics = np.zeros((n_context, n_context, n_train, n_epoch))
    function_output_dynamics = np.zeros((n_train, n_epoch))

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    for ep in range(n_epoch):
        # Probe the current weights before this epoch's update. Every result
        # is copied into NumPy, so no autograd graph is needed.
        model.eval()
        with torch.no_grad():
            # Each batch overwrites column `ep`; with a single full batch this
            # records every training sample once.
            for train_x, _ in trainloader:
                train_x = train_x.to(device)
                wv_p = attn.Wv[:n_hidden]
                wv_n = attn.Wv[n_hidden:]
                wq = attn.Wq
                wk = attn.Wk

                # 1-D vector, so it is used directly (x.T on 1-D is deprecated).
                feature = train_x[0, 0]
                feature_learning_p[:, ep] = _to_numpy(torch.matmul(wv_p, feature))
                feature_learning_n[:, ep] = _to_numpy(torch.matmul(wv_n, feature))
                feature_wq_inner_product[:, ep] = _to_numpy(torch.matmul(wq, feature))
                feature_wk_inner_product[:, ep] = _to_numpy(torch.matmul(wk, feature))

                if n_context == 2:
                    noise = train_x[:, noise_start]  # (n_train, d)
                    noise_memorization_p[:, :, ep] = _to_numpy(torch.matmul(wv_p, noise.T))
                    noise_memorization_n[:, :, ep] = _to_numpy(torch.matmul(wv_n, noise.T))
                    noise_wq_inner_product[:, :, ep] = _to_numpy(torch.matmul(wq, noise.T))
                    noise_wk_inner_product[:, :, ep] = _to_numpy(torch.matmul(wk, noise.T))
                else:
                    noise = train_x[:, noise_start:n_context]  # (n_train, n_noise, d)
                    # (n_train, n_noise, k) -> (n_noise, k, n_train) to match the arrays.
                    noise_memorization_p[:, :, :, ep] = _to_numpy(torch.matmul(noise, wv_p.T).permute(1, 2, 0))
                    noise_memorization_n[:, :, :, ep] = _to_numpy(torch.matmul(noise, wv_n.T).permute(1, 2, 0))
                    noise_wq_inner_product[:, :, :, ep] = _to_numpy(torch.matmul(noise, wq.T).permute(1, 2, 0))
                    noise_wk_inner_product[:, :, :, ep] = _to_numpy(torch.matmul(noise, wk.T).permute(1, 2, 0))
        model.train()

        train_loss, train_acc, packed = train(model, optimizer, trainloader, ep, verbose=verbose)
        test_loss, test_acc = test(model, testloader, ep)

        train_loss_values.append(train_loss)
        train_acc_values.append(train_acc)
        test_acc_values.append(test_acc)
        test_loss_values.append(test_loss)

        # Attention weights and outputs of the last training batch.
        attn_weights, _, _, out = packed
        # attn_weights is 4-D with the batch axis first; the permute moves the
        # batch axis last.
        # NOTE(review): the permuted tensor has one more axis than the
        # (n_context, n_context, n_train) slot. NumPy accepts this only if that
        # leading axis has size 1 (presumably a single head) — confirm.
        attn_weights_dynamics[:, :, :, ep] = _to_numpy(attn_weights.permute(1, 2, 3, 0))
        function_output_dynamics[:, ep] = _to_numpy(out)

        if verbose:
            print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e} | train_correct={train_acc} | test_loss={test_loss:0.5e} | test_correct={test_acc}')

    state_dict = {
        "train_loss_values": torch.as_tensor(train_loss_values),
        "test_loss_values": torch.as_tensor(test_loss_values),
        "train_acc_values": torch.as_tensor(train_acc_values),
        "test_acc_values": torch.as_tensor(test_acc_values),
        "final_test_loss": test_loss_values[-1],
        "final_train_loss": train_loss_values[-1],
        "final_test_acc": test_acc_values[-1],
        "final_train_acc": train_acc_values[-1],
        "noise_memorization_p": noise_memorization_p,
        "noise_memorization_n": noise_memorization_n,
        "feature_learning_p": feature_learning_p,
        "feature_learning_n": feature_learning_n,
        "feature_wq_inner_product": feature_wq_inner_product,
        "noise_wq_inner_product": noise_wq_inner_product,
        "feature_wk_inner_product": feature_wk_inner_product,
        "noise_wk_inner_product": noise_wk_inner_product,
        "attn_weights_dynamics": attn_weights_dynamics,
        "function_output_dynamics": function_output_dynamics,
    }
    if save_path is not None:
        path = f"{save_path}/dynamics.pt"
        print(f"save to path: {path}")
        torch.save(state_dict, path)
    return state_dict


def main_worker():
    """Run the full hyper-parameter sweep and save one dynamics dict per run.

    The loops are nested in this order: ``noise_level_list``,
    ``model_cls_list``, ``act_linear_list``, ``optim_name_list``, that
    optimizer's learning-rate list, then ``seeds``. For each combination it:

    * regenerates the data with ``prepare_data``;
    * builds a fresh model and optimizer;
    * trains with ``one_run``;
    * unless ``DEBUG`` is set, saves the returned dict with ``torch.save``
      under ``results_ods_finer/<other_setting>/<data_setting>/``.
    """
    # DEBUG=True turns on per-epoch logging in one_run and skips saving.
    DEBUG = False
    # NOTE(review): seed_data is not used below; each run's `seed` is what reaches prepare_data.
    seed_data = 219

    # data
    data_type = "one_direction_sparse" # fixed
    n_dim = 2000
    # Declared global because one_run() reads n_context (and n_train) as module-level names.
    global n_context
    n_context = 2 # fixed
    feature_ratio = 0.5 # fixed
    positive_ratio = 0.5 # fixed
    sparsity_level = int(0.04 * n_dim)
    # sparsity_level = None
    # noise_level = 1.0 / math.sqrt(n_dim) # 0.02236
    noise_level = 2.0 / math.sqrt(sparsity_level) # ~0.2236 for sparsity_level=80; replaced by each sweep value below
    # noise_level = 0.5 / math.sqrt(sparsity_level) # 0.1
    # Swept noise levels; each gives snr = 1 / coefficient, i.e. 1e4, 1e3, 20, 0.5.
    noise_level_list = np.array([1e-4, 1e-3, 5e-2, 2.0]) / math.sqrt(sparsity_level)
    snr = 1.0 / (math.sqrt(sparsity_level) * noise_level) # 0.5 for the noise_level above
    snr_list = [snr]  # not used below
    ortho = False

    # training
    global n_train
    global n_test
    global n_epoch
    n_train = int(0.01 * n_dim)
    n_test = int(500)
    n_epoch = 2000
    other_setting = f"tr{n_train}_te{n_test}_ep{n_epoch}"

    # "noise_level" and "snr" are overwritten for every sweep value below.
    data_config_dict = dict(
        setting=data_type, 
        n_train=n_train,
        n_test=n_test,
        n_context=n_context, 
        n_dim=n_dim,
        # seed_data=seed, NOTE: set it when created
        feature_ratio=feature_ratio,
        positive_ratio=positive_ratio,
        sparsity_level=sparsity_level,
        noise_level=noise_level,
        snr=snr,
        ortho=ortho,
        overlap=True,
    )

    # model
    n_hidden = 20  # not used below; one_run derives its own n_hidden from model.attn.dv
    act_q = 3
    act_linear = False
    act_linear_list = [True]
    model_cls = "AttnBinary" # variable
    model_cls_list = ["AttnBinary", "AttnBinaryV"]
    attn_type = "softmax" # variable
    fixed_qk = False # variable
    fixed_qk_list = [False]  # not used below
    seeds = [0]

    # "model_cls" and "act_linear" are overwritten inside the sweep.
    model_config_dict = dict(
        model_cls=model_cls,
        hidden_size=n_dim,
        act_q=act_q,
        act_linear=act_linear,
        dk=int(n_dim * 0.05),
        dv=int(n_dim * 0.01),
        num_attention_heads=1,
        attn_type=attn_type,
        fixed_qk=fixed_qk,
        fixed_v=False,
        attn_scaling_type="1",
    )
    
    # optim
    optim_name = "gd" # variable
    # optim_name_list = ["mysign"]
    optim_name_list = ['gd']
    # optim_name_list = ["adam", "gd", "adam0", "adam5"]
    # optim_name_list = ["mysign", "adam", "gd", "adam0", "adam5"]
    optim_cls_dict = dict(
        gd=torch.optim.SGD,
        gdm=torch.optim.SGD,
        adam=torch.optim.Adam,
        sign=torch.optim.Adam,
        mysign=optim.signGD,
        adam0=torch.optim.Adam,
        adam5=torch.optim.Adam,
        normgd=optim.normalizedGD,
    )
    # Base kwargs per optimizer; "lr" here is replaced by each value in optim_lr_dict.
    optim_args_dict=dict(
        gd=dict(lr=1.0),
        gdm=dict(lr=0.05, momentum=0.9),
        adam=dict(lr=1e-4, eps=1e-15),
        sign=dict(lr=1e-3, betas=(0.0, 0.0)),
        mysign=dict(lr=1e-4),
        adam0=dict(lr=1e-3, betas=(0.0, 0.999), eps=1e-15),
        adam5=dict(lr=1e-3, betas=(0.5, 0.999), eps=1e-15),
        normgd=dict(lr=1e-3),
    )
    optim_lr_dict=dict(
        gd=[1e-4],
        gdm=[1e-1],
        adam=[1e-3, 1e-4],
        mysign=[1e-3, 1e-4, 1e-7],
        adam0=[1e-3, 1e-4],
        adam5=[1e-3, 1e-4],
        normgd=[1e-1, 1e-2],
        sign=[1e-3, 1e-2],
    )

    for noise_level in noise_level_list:
        snr = 1.0 / (math.sqrt(sparsity_level) * noise_level) # 1 / the coefficient in noise_level_list
        data_config_dict["noise_level"] = noise_level
        data_config_dict["snr"] = snr
        # all data settings done
        
        for model_cls in model_cls_list:
            model_config_dict["model_cls"] = model_cls
            for act_linear in act_linear_list:
                model_config_dict["act_linear"] = act_linear
                # The linear-activation variant is skipped for AttnBinaryV.
                if act_linear and model_cls == "AttnBinaryV":
                    continue
                # all model settings done

                for optim_name in optim_name_list:
                    optim_cls = optim_cls_dict[optim_name]
                    # Copy so setting "lr" below does not mutate optim_args_dict.
                    optim_args = optim_args_dict[optim_name].copy()
                    optim_lr_list = optim_lr_dict[optim_name]
                    for lr in optim_lr_list:
                        optim_args["lr"] = lr
                        # all optim settings done

                        for seed in seeds:
                            # this seed is for both data and model
                            # NOTE(review): `seed` is only passed to prepare_data; no explicit
                            # torch.manual_seed precedes get_model, so model init is reproducible
                            # only if prepare_data seeds torch's global RNG — confirm.

                            # Data
                            print('==> Preparing data..')
                            train_x, train_y, test_x, test_y, _ = prepare_data(
                                seed_data=seed,
                                **data_config_dict,
                            )
                            trainset = TensorDataset(train_x, train_y)
                            # Full-batch loader: one_run's per-sample arrays assume one batch of n_train.
                            trainloader = torch.utils.data.DataLoader(
                                trainset, batch_size=n_train, shuffle=False)

                            testset = TensorDataset(test_x, test_y)
                            testloader = torch.utils.data.DataLoader(
                                testset, batch_size=n_test, shuffle=False)

                            # one run 
                            model_config = ModelConfig(**model_config_dict)
                            model = get_model(model_config)
                            model.to(device)
                            optimizer = optim_cls(model.parameters(), **optim_args)
                            # save_path=None: one_run does not save; saving happens below.
                            state_dict = one_run(model, optimizer, trainloader, testloader, n_epoch, None, verbose=DEBUG)

                            # save the state_dict
                            if not DEBUG:
                                if n_context == 2:
                                    data_setting = f"ods_d{n_dim}_sparse{sparsity_level}_noise{noise_level}_ortho{ortho}_seed{seed}"
                                else:
                                    data_setting = f"ods_d{n_dim}_l{n_context}_fr{feature_ratio}_sparse{sparsity_level}_noise{noise_level}_ortho{ortho}_seed{seed}"
                                model_act = "linearact" if model_config.act_linear else f"{model_config.act_q}-relu"
                                # Both the "q" and "k" tags come from the single fixed_qk flag.
                                model_fixq = "q" if model_config.fixed_qk else ""
                                model_fixk = "k" if model_config.fixed_qk else ""
                                model_fixv = "v" if model_config.fixed_v else ""
                                model_fix = "none" if (not model_config.fixed_qk and not model_config.fixed_v) else model_fixq + model_fixk + model_fixv
                                model_setting = f"{model.name}_d{n_dim}dk{model_config.dk}dv{model_config.dv}_{model_config.attn_type}attn-{model_config.attn_scaling_type}_{model_act}_fix{model_fix}_seed{seed}"
                                optim_setting = f"{optim_name}_lr{optim_args['lr']}"
                                
                                state_dict_dir = f"results_ods_finer/{other_setting}/{data_setting}"
                                # Saved without a file extension.
                                state_dict_path = f"{state_dict_dir}/{optim_setting}_{model_setting}"
                                print(state_dict_path)
                                os.makedirs(state_dict_dir, exist_ok=True)
                                torch.save(state_dict, state_dict_path)
                    # Printed once per optimizer, after all its learning rates and seeds.
                    print(optim_name)

if __name__ == "__main__":
    # Run the sweep only when executed as a script, not on import.
    main_worker()
                    