import argparse
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
from torch.autograd import Variable
import os
import timeit

from dataset import get_data, process_data
from printer import Printer
from util import exp_set_up, print_arg, set_random, accuracy, AverageMeter
from model import Classification_Net, NTK, Hinge_Loss
import model
import yaml
import shutil
import sys
import matplotlib.pyplot as plt
from metric import MDS_plot, TSNE_plot

# Per-epoch metric histories. train() appends one entry per epoch (including
# the epoch-0 evaluation) and the __main__ section plots and saves them.
train_losses, test_losses = [], []
train_acces, test_acces = [], []

# Output locations, filled in by train(): loss_path is the return value of
# exp_set_up(cf), figure_path is cf["test_path"].
loss_path = figure_path = None

# Reference direction built by train() for non-real datasets: the normalised
# sum of the first cf["data"]["effective_dim"] rows of W, plus a copy repeated
# cf["model"]["hidden_dim"] times and a copy with a leading axis of size 1.
good_weights = good_weights_repeat = good_weights_exp = None

# Set by train(): True when cf["data"]["dataset"] is MNIST, CIFAR10 or SVHN.
real_dataset = None

# Max cosine similarity between normalised fc1 rows and good_weights;
# test_model() appends to it whenever a teacher is present.
cos_simlarity_list = []


def _print_epoch_summary(cf, epoch, train_loss, train_acc, test_loss, test_acc):
    """Print one epoch's train/test loss and accuracy in the log's fixed format."""
    print('[%3d/%3d] train_avg_loss %.4f | train_avg_acc %.4f |  test_avg_loss %.4f | test_avg_acc %.4f |'
          % (epoch, cf["train"]["niter"], train_loss, train_acc, test_loss, test_acc))


def _record_epoch(train_loss, train_acc, test_loss, test_acc):
    """Append one epoch's metrics to the module-level history lists."""
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_acces.append(train_acc)
    test_acces.append(test_acc)


def _freeze_fc1(net):
    """Disable gradients for every parameter whose name contains "fc1"."""
    for name, param in net.named_parameters():
        if "fc1" in name:
            param.requires_grad = False


def _save_checkpoint(cf, net, epoch):
    """Save the network's state_dict under cf["checkpoint_path"] for this epoch."""
    torch.save(net.state_dict(),
               '{0}/net_epoch_{1}.pth'.format(cf["checkpoint_path"], epoch))


def train(args):
    """Run the whole training schedule described by the YAML file ``args.config``.

    Schedule, as implemented below:

    * epoch 0: evaluation only; the test metrics are logged and recorded as
      both the "train" and "test" entries.
    * epochs 1 and 2: loaders use ``cf["data"]["data_size"]`` as batch size and
      no optimizer is passed, so ``train_model`` builds its own through
      ``get_optimizer``. After epoch 1 the fc1 parameters are frozen when the
      model type is "one"; after epoch 2 when it is "two".
    * epochs 3..``cf["train"]["niter"]``: loaders use
      ``cf["data"]["batch_size"]`` and a single Adam or SGD optimizer.

    Side effects: sets the module globals ``loss_path``, ``figure_path``,
    ``real_dataset`` (and the ``good_weights*`` globals for non-real datasets),
    appends to the module-level metric lists, copies the config file into
    ``cf["exp_path"]``, opens ``cf["log_path"]`` and writes checkpoints.

    Args:
        args: object with a ``config`` attribute holding the YAML config path.

    Returns:
        Tuple ``(cf, net, Q, dim_full)``: the config dict, the trained network,
        the ``Q`` value returned by ``get_data`` and ``cf["data"]["dim_full"]``.

    Raises:
        ValueError: if the config is not a mapping, or names an unsupported
            loss or optimizer.
    """
    global loss_path, figure_path, good_weights, good_weights_repeat, good_weights_exp, real_dataset
    config_file = args.config
    # yaml.load() without an explicit Loader is rejected by PyYAML >= 6 and is
    # unsafe on older releases; safe_load handles plain-data configs on every
    # version. The context manager closes the file handle.
    with open(config_file, 'r') as config_stream:
        cf = yaml.safe_load(config_stream)
    if not isinstance(cf, dict):
        raise ValueError("config file %r does not contain a YAML mapping" % (config_file,))
    # splitext instead of [:-5] so any extension length (.yaml, .yml) works.
    cf["name"] = os.path.splitext(os.path.basename(config_file))[0]

    # The cf path keys read below (exp_path, log_path, ...) are not set in this
    # function; presumably exp_set_up adds them -- confirm in util.
    loss_path = exp_set_up(cf)
    figure_path = cf["test_path"]

    shutil.copyfile(config_file, os.path.join(
        cf["exp_path"], cf["name"]+".yaml"))
    # Left open on purpose: the handle is handed to Printer and the caller keeps
    # printing after train() returns (presumably Printer mirrors stdout into
    # it -- confirm in printer module).
    log_file = open(cf["log_path"], 'w+')
    printer = Printer(sys.stdout, log_file).open()
    print_arg(cf)

    dim_full = cf["data"]["dim_full"]
    width = cf["model"]["hidden_dim"]
    label_num = cf["model"]["label_num"]
    cudnn.benchmark = True

    USE_CUDA = False
    # Get data
    train_data, test_data, Q, W, label_weight = get_data(cf)

    real_dataset = cf["data"]["dataset"] in ["MNIST", "CIFAR10", "SVHN"]

    if not real_dataset:
        cf["teacherSeed"] = 1128  # fix seed
        print("Teacher Seed: ", cf["teacherSeed"])
        set_random(cf["teacherSeed"])
        teacher_width = cf["data"]["feature_dim"]

        teacher = Classification_Net(dim_full, teacher_width, label_num)
        teacher.apply(model.weights_init)

        for name, param in teacher.named_parameters():
            if "fc1" in name and "weight" in name:
                # NOTE(review): the original had `param = W` here, which only
                # rebound the loop variable and never changed the teacher. The
                # dead assignment is dropped; the teacher's fc1 weights are
                # still NOT set to W. In this file the teacher is only used
                # as an "is not None" flag.
                good_weights = torch.sum(W[:cf["data"]["effective_dim"]], dim=0)
                good_weights = good_weights / torch.norm(good_weights)
                good_weights_repeat = good_weights.repeat(
                    cf["model"]["hidden_dim"], 1)
                good_weights_exp = torch.unsqueeze(good_weights, dim=0)
    else:
        teacher = None

    cf["manualSeed"] = 1126  # fix seed
    print("Random Seed: ", cf["manualSeed"])
    set_random(cf["manualSeed"])
    if cf["model"]["type"] == "ntk":
        net = NTK(dim_full, width, label_num)
        net.init_weight_grad()
    else:
        net = Classification_Net(dim_full, width, label_num)
        if real_dataset:
            net.apply(model.weights_init)
        else:
            net.init_weight_grad(cf["data"]["effective_dim"])

        # fc1 is trainable unless the model type is "fix"; fc2 always is.
        for name, param in net.named_parameters():
            if "fc1" in name:
                param.requires_grad = cf["model"]["type"] != "fix"
            if "fc2" in name:
                param.requires_grad = True

    if USE_CUDA:
        net.cuda()

    # Only pass `weight` when get_data supplied one, exactly as before.
    loss_kwargs = {} if label_weight is None else {"weight": label_weight}
    if cf["train"]["loss"] == "cross":
        criterion = nn.CrossEntropyLoss(**loss_kwargs)
    elif cf["train"]["loss"] == "hinge":
        criterion = Hinge_Loss(**loss_kwargs)
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError("unsupported cf['train']['loss']: %r" % (cf["train"]["loss"],))

    start_time = timeit.default_timer()
    num_epochs = cf["train"]["niter"]

    def make_loader(dataset, batch_size):
        """Build a shuffled DataLoader with the worker count from the config."""
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
            num_workers=cf["data"]["num_workers"])

    def run_epoch(epoch, trainloader, testloader, optimizer):
        """Train then evaluate one epoch, record the metrics and return them."""
        train_avg_loss, train_avg_acc = train_model(
            cf, epoch, trainloader, optimizer, net, criterion, start_time, USE_CUDA, teacher)
        test_avg_loss, test_avg_acc = test_model(
            cf, epoch, net, testloader, criterion, Q, USE_CUDA, teacher)
        _record_epoch(train_avg_loss, train_avg_acc, test_avg_loss, test_avg_acc)
        return train_avg_loss, train_avg_acc, test_avg_loss, test_avg_acc

    trainloader = make_loader(train_data, cf["data"]["data_size"])
    testloader = make_loader(test_data, cf["data"]["data_size"])

    # Epoch 0: evaluation only. There are no training metrics yet, so the test
    # metrics are logged and recorded in the "train" slots as well.
    test_avg_loss, test_avg_acc = test_model(
        cf, 0, net, testloader, criterion, Q, USE_CUDA, teacher)
    _print_epoch_summary(cf, 0, test_avg_loss, test_avg_acc, test_avg_loss, test_avg_acc)
    _record_epoch(test_avg_loss, test_avg_acc, test_avg_loss, test_avg_acc)

    # =========================
    # Epochs 1 and 2: optimizer=None makes train_model build its own optimizer.
    # Model type "one" freezes fc1 after epoch 1, type "two" after epoch 2.
    for epoch, freeze_type in ((1, "one"), (2, "two")):
        metrics = run_epoch(epoch, trainloader, testloader, None)
        _print_epoch_summary(cf, epoch, *metrics)
        _save_checkpoint(cf, net, epoch)
        if cf["model"]["type"] == freeze_type:
            _freeze_fc1(net)

    # With "l1" decay the optimizer gets no weight decay; train_model adds an
    # explicit L1 term to the loss instead.
    if cf["train"]["decay"] == "l1":
        weight_decay = 0
    else:
        weight_decay = cf["train"]["weight_decay"]

    if cf["train"]["optimizer"] == "adam":
        optimizer = torch.optim.Adam(
            net.parameters(), lr=cf["train"]["lr"], weight_decay=weight_decay)
    elif cf["train"]["optimizer"] == "sgd":
        optimizer = torch.optim.SGD(
            net.parameters(), lr=cf["train"]["lr"], momentum=0.95, weight_decay=weight_decay)  # TODO
    else:
        # An assert would be stripped under `python -O`; raise explicitly.
        raise ValueError("unsupported cf['train']['optimizer']: %r" % (cf["train"]["optimizer"],))

    # Epochs 3..niter use the regular batch size.
    trainloader = make_loader(train_data, cf["data"]["batch_size"])
    testloader = make_loader(test_data, cf["data"]["batch_size"])
    for epoch in range(3, num_epochs + 1):
        metrics = run_epoch(epoch, trainloader, testloader, optimizer)
        if epoch % 20 == 0:
            _print_epoch_summary(cf, epoch, *metrics)
        if epoch % 200 == 0:
            _save_checkpoint(cf, net, epoch)
    return cf, net, Q, dim_full


def get_optimizer(cf, net, epoch):
    """Build a momentum-free SGD optimizer whose learning rate depends on the epoch.

    For model types "ntk"/"fix", or when the module global ``real_dataset`` is
    truthy, the learning rate is simply ``cf["train"]["lr"]``. Otherwise:

    * epoch 1: a constant over ``effective_dim * feature_dim * hidden_dim``,
      where the constant depends on ``cf["data"]["label_function"]``
      (1 for "parity", 10 for "interval"/"all", 1/3 for anything else);
    * epoch 2: 1;
    * later epochs: ``1 / (epoch * sqrt(hidden_dim / 2))``.

    Weight decay is 0 when ``cf["train"]["decay"]`` is "l1", else
    ``cf["train"]["weight_decay"]``.

    Returns:
        Tuple ``(optimizer, lr)``.
    """
    train_cf = cf["train"]
    weight_decay = 0 if train_cf["decay"] == "l1" else train_cf["weight_decay"]

    if cf["model"]["type"] in ("ntk", "fix") or real_dataset:
        lr = train_cf["lr"]
    elif epoch == 1:
        data_cf = cf["data"]
        label_fn = data_cf["label_function"]
        dims = data_cf["effective_dim"] * data_cf["feature_dim"] * cf["model"]["hidden_dim"]
        if label_fn == "parity":
            lr = 1 / dims
        elif label_fn in ("interval", "all"):
            lr = 10 / dims
        else:
            lr = 1 / (3 * dims)
    elif epoch == 2:
        lr = 1
    else:
        lr = 1 / (epoch * (cf["model"]["hidden_dim"]/2)**0.5)

    sgd = torch.optim.SGD(
        net.parameters(), lr=lr, momentum=0.0, weight_decay=weight_decay)  # TODO
    return sgd, lr

def train_model(cf, epoch, trainloader, optimizer, net, criterion, time, USE_CUDA, teacher=None):
    """Run one optimisation step on the first batch of ``trainloader``.

    When ``optimizer`` is None or ``epoch <= 2`` a fresh optimizer (and its
    learning rate) is obtained from ``get_optimizer``. The reported loss is the
    plain criterion value, measured before any penalty term is added.

    Args:
        time: reference timestamp from ``timeit.default_timer()`` used for the
            elapsed-minutes figure in the progress line.
        teacher: accepted for signature parity with ``test_model``; unused here.

    Returns:
        Tuple ``(average loss, average accuracy)`` from the AverageMeters.
    """
    if optimizer is None or epoch <= 2:
        optimizer, lr = get_optimizer(cf, net, epoch)
    net.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    step = 0
    for batch in trainloader:
        x, label = process_data(batch, USE_CUDA, cf)
        x = Variable(x)
        # ===================forward=====================
        y, _ = net(x, epoch)
        loss = criterion(y, label)
        # Captured before the penalties below, so the meters track the raw loss.
        loss_item = loss.item()

        plain_lr_mode = cf["model"]["type"] in ["ntk", "fix"] or real_dataset
        if not plain_lr_mode and epoch in (1, 2):
            # Squared-weight penalty scaled by 1/(2*lr): fc1 only in epoch 1,
            # fc1 + fc2 in epoch 2.
            penalty = torch.sum(net.fc1.weight * net.fc1.weight)
            if epoch == 2:
                penalty = penalty + torch.sum(net.fc2.weight * net.fc2.weight)
            loss += 1/(2*lr) * penalty

        batch_size = x.shape[0]
        loss_meter.update(loss_item, batch_size)
        acc1 = accuracy(y, label)[0]
        acc_meter.update(acc1.item(), batch_size)

        l1_strength = cf["train"]["weight_decay"] if cf["train"]["decay"] == "l1" else None
        if l1_strength is not None and l1_strength > 0:
            # NOTE(review): this reads net.fc, while the rest of this function
            # uses net.fc1 / net.fc2 -- confirm the model exposes `fc`.
            loss += cf["train"]["weight_decay"] * \
                torch.sum(torch.abs(net.fc.weight))

        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1
        if step % 100 == 0:
            elapsed_min = (timeit.default_timer() - time) / 60.0
            print('[%3d/%3d][%3d/%3d] (%.2f m) loss %.4f | acc %.4f'
                  % (epoch, cf["train"]["niter"], step, len(trainloader),
                     elapsed_min, loss_item, acc1))
        # Unconditional: only the first batch of the loader is consumed per call.
        break
    return loss_meter.avg, acc_meter.avg


def test_model(cf, epoch, net, testloader, criterion, Q, USE_CUDA, teacher=None):
    """Evaluate ``net`` on the first test batch and periodically snapshot fc1.

    Loss and accuracy are measured under ``torch.no_grad()`` on the first batch
    only. When ``epoch <= 3`` or ``epoch`` is a multiple of 10, the row-normalised
    fc1 weights are saved to ``cf["fig_path"]``; if a teacher is present the
    maximum cosine similarity to the module global ``good_weights_repeat`` is
    printed and appended to ``cos_simlarity_list``, and ``+/- good_weights_exp``
    are appended to the saved rows. MDS/TSNE plots are drawn for model type
    "adp" or for real datasets.

    Returns:
        Tuple ``(average loss, average accuracy)`` from the AverageMeters.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    with torch.no_grad():
        for data in testloader:
            x, label = process_data(data, USE_CUDA, cf)
            batch_size = x.shape[0]
            y, _ = net(x)

            loss_meter.update(criterion(y, label).item(), batch_size)
            acc_meter.update(accuracy(y, label)[0].item(), batch_size)
            # Only the first batch is evaluated.
            break

    snapshot_due = epoch <= 3 or epoch % 10 == 0
    if not snapshot_due:
        return loss_meter.avg, acc_meter.avg

    # Normalise each fc1 row to unit length; all-zero rows are left untouched.
    test_vectors = net.fc1.weight.detach().clone()
    for row_idx in range(test_vectors.shape[0]):
        row_norm = torch.norm(test_vectors[row_idx])
        if row_norm == 0:
            continue
        test_vectors[row_idx] = test_vectors[row_idx] / row_norm

    if teacher is not None:
        similarity = nn.CosineSimilarity(dim=1, eps=1e-6)(
            test_vectors, good_weights_repeat)
        best_similarity = float(torch.max(similarity))
        print("cos_simlarity", best_similarity)
        cos_simlarity_list.append(best_similarity)
        # Append the reference direction and its negation to the saved rows.
        test_vectors = torch.cat(
            (test_vectors, good_weights_exp, -good_weights_exp), dim=0)

    fig_dir = cf["fig_path"]
    torch.save(test_vectors, os.path.join(fig_dir, "X_" + str(epoch) + ".pt"))

    if cf["model"]["type"] in ["adp"] or real_dataset:
        test_vectors = test_vectors.detach().numpy()
        # Real datasets pass red=False; otherwise the plot default is used.
        plot_kwargs = {"red": False} if real_dataset else {}
        MDS_plot(test_vectors,
                 os.path.join(fig_dir, "MDS_" + str(epoch) + ".png"), **plot_kwargs)
        TSNE_plot(test_vectors,
                  os.path.join(fig_dir, "TSNE_" + str(epoch) + ".png"), **plot_kwargs)

    return loss_meter.avg, acc_meter.avg



if __name__ == '__main__':
    # Script entry point: train from the --config file, then plot the recorded
    # loss/accuracy (and cosine-similarity) curves and save all results.
    print(torch.__version__)

    parser = argparse.ArgumentParser("Training model")
    parser.add_argument(
        "--config", default="./conf/sample.yaml", help="Config file.")
    args = parser.parse_args()
    cf, net, Q, dim_full = train(args=args)

    fig_dir = cf["fig_path"]
    exp_results = {"cf": cf}

    # (train key, train series, train label, test key, test series, test label, file)
    curve_specs = (
        ("train_losses", train_losses, 'train loss',
         "test_losses", test_losses, 'test loss', "loss.png"),
        ("train_acces", train_acces, 'train acc',
         "test_acces", test_acces, 'test acc', "acc.png"),
    )
    for (train_key, train_curve, train_label,
         test_key, test_curve, test_label, png_name) in curve_specs:
        plt.clf()
        plt.plot(train_curve, "g", label=train_label)
        plt.plot(test_curve, "b", label=test_label)
        exp_results[train_key] = train_curve
        exp_results[test_key] = test_curve
        plt.legend()
        plt.savefig(os.path.join(fig_dir, png_name))

    # The list is only populated when a teacher was present during evaluation.
    if cos_simlarity_list:
        plt.clf()
        plt.plot(cos_simlarity_list, "g", label='cos simlarity')
        exp_results["cos_simlarity_list"] = cos_simlarity_list
        plt.legend()
        plt.savefig(os.path.join(fig_dir, "cos.png"))

    torch.save(exp_results, os.path.join(fig_dir, "exp_results.pt"))
