import os
import sys
import copy
import argparse
import wandb
import torch
import torch.nn.functional as F
import numpy as np
import scipy
import random
import matplotlib.pyplot as plt

import datasets
import utils

from utils import mlp_agop
from tqdm import tqdm
from models import neural_nets

# Seed every random number generator the script touches (Python's `random`,
# NumPy's legacy global RNG, and torch) so repeated runs start from the same
# random state.
random.seed(253)
np.random.seed(1145)
torch.manual_seed(3143)

# Use 32-bit floats as torch's default floating-point dtype.
torch.set_default_dtype(torch.float32)

def _parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()

    # Logging / environment.
    parser.add_argument('--wandb_entity', default='default')
    parser.add_argument('--wandb_proj_name', default='default')
    parser.add_argument('--wandb_offline', default=False, action='store_true')
    parser.add_argument('--group_key', default='', type=str)
    parser.add_argument('--out_dir', default='./wandb')
    parser.add_argument('--data_root', default='./data')
    parser.add_argument('--device', default='cuda', choices=('cuda', 'cpu'))

    # Data.
    parser.add_argument('--dataset', default='modular_arithmetic')
    parser.add_argument('--operation', '-op', default="x+y")
    parser.add_argument('--prime', '-p', default=61, type=int)
    parser.add_argument('--training_fraction', default=0.5, type=float)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--agop_batch_size', default=32, type=int)

    # Model.
    # NOTE(review): '--model' is accepted but never read below; an FCN is
    # always constructed. Confirm whether 'CNN' is meant to be supported.
    parser.add_argument('--model', default='FCN', choices=('FCN', 'CNN'))
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--agop_log_freq', default=10, type=int)
    parser.add_argument('--n_hid_layers', default=1, type=int)
    parser.add_argument('--hidden_width', default=256, type=int)
    parser.add_argument('--init_scale', default=1.0, type=float)
    parser.add_argument("--act_fn", type=str, default="quadratic",
                        choices=('relu', 'quadratic', 'swish', 'softplus', 'linear'))

    # Optimisation.
    parser.add_argument('--opt', default='adamw', choices=('sgd', 'adamw'))
    parser.add_argument('--loss', default='mse', choices=('mse', 'xent'))
    parser.add_argument('--learning_rate', default=1e-3, type=float)
    parser.add_argument('--weight_decay', default=1.0, type=float)
    parser.add_argument('--momentum', default=0.0, type=float)
    parser.add_argument('--scheduler', default='none')
    return parser.parse_args()


def _build_criterion(loss_name, out_dim):
    """Return the torch loss module for ``loss_name`` ('mse' or 'xent').

    For 'xent' a single-output model gets ``BCELoss``; otherwise
    ``CrossEntropyLoss`` is used.

    Raises:
        ValueError: if ``loss_name`` is not a supported loss.
    """
    if loss_name == 'mse':
        return torch.nn.MSELoss()
    if loss_name == 'xent':
        if out_dim == 1:
            return torch.nn.BCELoss()
        return torch.nn.CrossEntropyLoss()
    raise ValueError(f"Unsupported loss: {loss_name!r}")


def _build_scheduler(args, optimizer):
    """Return the LR scheduler selected by ``args.scheduler``, or ``None`` for 'none'.

    Raises:
        ValueError: if ``args.scheduler`` is not 'none', 'cosine' or 'step1'.
            Failing here avoids discovering the problem only after a full
            training epoch has already run.
    """
    if args.scheduler == 'none':
        return None
    if args.scheduler == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(args.epochs))
    if args.scheduler == 'step1':
        # Fixed epoch milestones; they are not scaled with --epochs.
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75, 90], gamma=0.1)
    raise ValueError(
        f"Unsupported scheduler: {args.scheduler!r} (expected 'none', 'cosine' or 'step1')"
    )


def _batch_metrics(output, labels, criterion, out_dim, loss_name):
    """Compute the loss and the number of correct predictions for one batch.

    Args:
        output: raw model output for the batch.
        labels: targets, already cast to the input dtype by the caller.
        criterion: loss module returned by ``_build_criterion``.
        out_dim: model output dimension; ``1`` selects the single-output path.
        loss_name: 'mse' or 'xent'.

    Returns:
        ``(loss, n_correct)`` where both are tensors.
    """
    if out_dim == 1:
        # Drop only the trailing singleton dimension. A bare ``squeeze()``
        # would also drop the batch dimension for a one-sample batch.
        # Assumes the model emits (batch, 1) here -- TODO confirm.
        output = output.squeeze(-1)
        if loss_name == 'mse':
            # Single-output regression: the sign of the value is the class.
            threshold = 0
            loss = torch.pow(output - labels, 2).mean()
        else:
            # Single-output 'xent' path: 0.5 is the decision boundary.
            threshold = 0.5
            loss = criterion(output, labels)
        predicted = (output > threshold).long()
        target = (labels > threshold).long()
        n_correct = (predicted == target).sum()
    else:
        # Labels are compared via argmax over the last dimension.
        target = labels.argmax(-1)
        n_correct = (output.argmax(-1) == target).sum()
        if loss_name == 'mse':
            loss = criterion(output, labels)
        else:
            loss = criterion(output, target.long())
    return loss, n_correct


def _flat_corr(a, b):
    """Return the Pearson correlation between the flattened arrays ``a`` and ``b``."""
    return np.corrcoef(a.flatten(), b.flatten())[0][1]


def _log_nfa_correlations(model, agop_loader, criterion, optimizer, args, out_dim, epoch, global_step):
    """Log correlations between the first-layer ``W^T W`` matrix and AGOP-style matrices.

    ``nfm`` is ``fc1.weight^T @ fc1.weight``. The matrices it is compared with
    come from ``mlp_agop.calc_full_agop`` and ``mlp_agop.calc_lgop``; their
    exact definitions live in those helpers.
    """
    fc1_weight = model.fc1.weight.data
    nfm = (fc1_weight.T @ fc1_weight).detach().cpu().numpy()

    with torch.no_grad():
        agop, _ = mlp_agop.calc_full_agop(model, agop_loader, args,
                                          calc_per_class_agops=False, detach=True)
    agop = agop.cpu().numpy()

    # Called outside ``no_grad`` as in the original script; it receives the
    # criterion and optimizer, so it presumably needs gradients -- TODO confirm.
    lgop, enfa = mlp_agop.calc_lgop(model, agop_loader, criterion, optimizer, args, out_dim,
                                    act_fn=args.act_fn)

    # Derive the symmetrised and Gram variants while ``lgop`` is still a tensor.
    lgop_sym = (0.5 * (lgop + lgop.t())).cpu().numpy()
    lgop_gram = (lgop @ lgop.t()).cpu().numpy()
    lgop = lgop.cpu().numpy()
    enfa = enfa.cpu().numpy()

    wagop = lgop @ nfm

    sqrt_agop = utils.matrix_power(agop, 0.5, is_torch=False)
    sqrt_wagop = utils.matrix_power(wagop, 0.5, is_torch=False)
    sqrt_enfa = utils.matrix_power(enfa, 0.5, is_torch=False)
    sqrt_lgop_gram = utils.matrix_power(lgop_gram, 0.5, is_torch=False)

    wandb.log({
        'NFA/sqrt_agop_nfm_corr': _flat_corr(sqrt_agop, nfm),
        'NFA/sqrt_wagop_nfm_corr': _flat_corr(sqrt_wagop, nfm),
        'NFA/lgop_nfm_corr': _flat_corr(lgop, nfm),
        'NFA/sqrt_enfa_nfm_corr': _flat_corr(sqrt_enfa, nfm),
        'NFA/sqrt_lgop_lgopT_nfm_corr': _flat_corr(sqrt_lgop_gram, nfm),
        'NFA/lgop3_nfm_corr': _flat_corr(lgop_sym, nfm),
        'epoch': epoch
    }, step=global_step)


def main():
    """Train an FCN and log training, validation and NFA metrics to wandb.

    Steps: parse CLI arguments, set up wandb, build the data loaders, model,
    optimizer, loss and optional LR scheduler, then for every epoch

    * train over ``train_loader`` and log per-batch metrics,
    * step the scheduler (if any),
    * evaluate on ``test_loader`` and log aggregate accuracy / loss,
    * every ``--agop_log_freq`` epochs log the NFA correlations
      (``--agop_log_freq 0`` disables this).

    Raises:
        ValueError: for an unsupported ``--loss`` or ``--scheduler`` value.
    """
    args = _parse_args()

    utils.setup_wandb(wandb, args)

    train_dataset, test_dataset, inp_dim, out_dim = datasets.load_dataset(args)
    train_loader = datasets.make_dataloader(train_dataset, args.batch_size, shuffle=True, drop_last=False)
    # A separate, unshuffled copy of the training set is used for the AGOP computations.
    agop_loader = datasets.make_dataloader(copy.deepcopy(train_dataset), args.agop_batch_size,
                                           shuffle=False, drop_last=True)
    test_loader = datasets.make_dataloader(test_dataset, args.batch_size, shuffle=False, drop_last=False)

    # BCELoss expects probabilities, so the single-output 'xent' setup asks
    # the model for a sigmoid output (flag semantics live in neural_nets.FCN).
    sigmoid_output = args.loss == 'xent' and out_dim == 1

    model = neural_nets.FCN(
        inp_dim=inp_dim,
        hidden_width=args.hidden_width,
        out_dim=out_dim,
        n_hid_layers=args.n_hid_layers,
        init_scale=args.init_scale,
        sigmoid_output=sigmoid_output
    ).to(args.device)

    optimizer = utils.get_optimizer(args, model)
    criterion = _build_criterion(args.loss, out_dim)
    scheduler = _build_scheduler(args, optimizer)

    global_step = 0
    for epoch in tqdm(range(args.epochs)):
        # ---- training -----------------------------------------------------
        model.train()
        for batch in train_loader:
            inputs, labels = (t.to(args.device) for t in batch)
            labels = labels.to(inputs.dtype)

            optimizer.zero_grad()
            output = model(inputs, act_fn=args.act_fn)

            loss, n_correct = _batch_metrics(output, labels, criterion, out_dim, args.loss)
            acc = n_correct / inputs.shape[0]

            # Weight norms are taken before the update, gradient norms after backward().
            fc1_w_norm = torch.linalg.norm(model.fc1.weight.detach())
            out_w_norm = torch.linalg.norm(model.out.weight.detach())

            loss.backward()

            fc1_grad_w_norm = torch.linalg.norm(model.fc1.weight.grad.detach())
            out_grad_w_norm = torch.linalg.norm(model.out.weight.grad.detach())

            optimizer.step()

            wandb.log({
                'training/batch_accuracy': acc,
                'training/batch_loss': loss,
                'training/fc1_w_norm': fc1_w_norm,
                'training/out_w_norm': out_w_norm,
                'training/fc1_grad_w_norm': fc1_grad_w_norm,
                'training/out_grad_w_norm': out_grad_w_norm,
                'epoch': epoch,
                'learning_rate': optimizer.param_groups[-1]['lr']
            }, step=global_step)

            global_step += 1

        # Schedulers are stepped once per epoch.
        if scheduler is not None:
            scheduler.step()

        # ---- validation ---------------------------------------------------
        model.eval()
        with torch.no_grad():
            n_correct = 0
            total_loss = 0
            total = 0
            for batch in test_loader:
                inputs, labels = (t.to(args.device) for t in batch)
                labels = labels.to(inputs.dtype)

                output = model(inputs, act_fn=args.act_fn)
                loss, batch_correct = _batch_metrics(output, labels, criterion, out_dim, args.loss)

                batch_size = inputs.shape[0]
                total += batch_size
                n_correct += batch_correct
                # Weight the mean batch loss by batch size so the final
                # figure is a per-sample average.
                total_loss += loss * batch_size

            # An empty test loader would otherwise divide by zero.
            if total > 0:
                wandb.log({
                    'validation/accuracy': n_correct / total,
                    'validation/loss': total_loss / total,
                    'epoch': epoch
                }, step=global_step)

        # ---- NFA diagnostics ------------------------------------------------
        # A frequency of 0 disables the diagnostics instead of raising ZeroDivisionError.
        if args.agop_log_freq != 0 and epoch % args.agop_log_freq == 0:
            _log_nfa_correlations(model, agop_loader, criterion, optimizer, args, out_dim,
                                  epoch, global_step)
# Run the training script only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
