import os
import time
import json
from math import prod
import wandb
import torch
from torch.nn import CrossEntropyLoss
from model import set_seed
from nn import cola_nn
import nn
from tqdm import tqdm
from torchinfo import summary
from fvcore.nn import FlopCountAnalysis
from scaling_mlps.data_utils import data_stats
from scaling_mlps.data_utils.dataloader import get_loader
from scaling_mlps.utils.config import config_to_name
from scaling_mlps.utils.get_compute import get_compute
from scaling_mlps.utils.metrics import topk_acc, real_acc, AverageMeter
from scaling_mlps.utils.optimizer import get_optimizer, get_scheduler
from scaling_mlps.utils.parsers import get_training_parser


def train(model, opt, scheduler, loss_fn, epoch, train_loader, args):
    """Train ``model`` for one pass over ``train_loader``.

    Handles mixup-style targets (when ``args.mixup > 0``), gradient
    accumulation over ``args.accum_steps`` batches and optional gradient-norm
    clipping (when ``args.clip > 0``). ``scheduler.step()`` is called once,
    after the pass.

    Args:
        model: network to train; switched to train mode.
        opt: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per call.
        loss_fn: criterion called as ``loss_fn(preds, targets)``.
        epoch: current epoch index (not used inside this function).
        train_loader: iterable of ``(ims, targs)`` batches.
        args: namespace providing ``mixup``, ``accum_steps`` and ``clip``.

    Returns:
        ``(top1 %, top5 %, mean loss, seconds)`` for the pass. The elapsed
        time is measured before the scheduler step.
    """
    t_start = time.time()
    model.train()

    acc_meter = AverageMeter()
    top5_meter = AverageMeter()
    loss_meter = AverageMeter()
    accum = args.accum_steps

    for step, (ims, targs) in enumerate(train_loader):
        preds = model(ims)
        batch_size = ims.shape[0]

        # With mixup, column 0 holds the target, column 1 the permuted target,
        # and element [0, 2] the mix weight (-1 means "no mixing this batch").
        targs_perm = None
        mix_weight = None
        if args.mixup > 0:
            mix_weight = targs[0, 2].squeeze()
            perm_candidate = targs[:, 1].long()
            targs = targs[:, 0].long()
            if mix_weight != -1:
                targs_perm = perm_candidate

        if targs_perm is None:
            loss = loss_fn(preds, targs)
        else:
            loss = loss_fn(preds, targs) * mix_weight + loss_fn(preds, targs_perm) * (1 - mix_weight)

        acc, top5 = topk_acc(preds, targs, targs_perm, k=5, avg=True)
        acc_meter.update(acc, batch_size)
        top5_meter.update(top5, batch_size)

        # Divide so gradients summed over `accum` micro-batches form an average.
        scaled_loss = loss / accum
        scaled_loss.backward()

        seen = step + 1
        # len() is only evaluated when not already at an accumulation boundary.
        if seen % accum == 0 or seen == len(train_loader):
            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            opt.step()
            opt.zero_grad()

        # Record the un-scaled loss value.
        loss_meter.update(scaled_loss.item() * accum, batch_size)

    elapsed = time.time() - t_start
    scheduler.step()

    return (
        acc_meter.get_avg(percentage=True),
        top5_meter.get_avg(percentage=True),
        loss_meter.get_avg(percentage=False),
        elapsed,
    )


@torch.no_grad()
def test(model, loader, loss_fn, args):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Args:
        model: network to evaluate; switched to eval mode (and left there).
        loader: iterable of ``(ims, targs)`` batches.
        loss_fn: criterion called as ``loss_fn(preds, targs)``.
        args: namespace providing ``dataset``.

    Returns:
        ``(top1 %, top5 %, mean loss, seconds)``. All three metrics are
        averaged per sample (each batch is weighted by its size). For
        ``args.dataset == 'imagenet_real'`` only ``real_acc`` is computed, so
        top-5 and loss are reported as 0.
    """
    start = time.time()
    model.eval()
    total_acc, total_top5, total_loss = AverageMeter(), AverageMeter(), AverageMeter()

    for ims, targs in loader:
        preds = model(ims)
        batch_size = ims.shape[0]
        if args.dataset != 'imagenet_real':
            acc, top5 = topk_acc(preds, targs, k=5, avg=True)
            loss = loss_fn(preds, targs).item()
        else:
            acc = real_acc(preds, targs, k=5, avg=True)
            top5 = 0
            loss = 0

        total_acc.update(acc, batch_size)
        total_top5.update(top5, batch_size)
        # Weight by batch size like the accuracy meters (and like train()), so a
        # smaller final batch does not skew the mean loss.
        total_loss.update(loss, batch_size)

    end = time.time()

    return (
        total_acc.get_avg(percentage=True),
        total_top5.get_avg(percentage=True),
        total_loss.get_avg(percentage=False),
        end - start,
    )


def _params_and_flops(model, input_size, probe_input):
    """Return ``(trainable_params, total_flops)`` for ``model``.

    ``torchinfo.summary`` builds its own input from ``input_size``. fvcore
    traces ``probe_input`` to count FLOPs.
    """
    stats = summary(model, input_size)
    flops = FlopCountAnalysis(model, probe_input).total()
    return stats.trainable_params, flops


def _hidden_state_norms(hs, prev_hs):
    """Return per-layer ``(h_norm, dh_norm)`` lists of scalar tensors.

    For each layer tensor ``h``, ``h_norm`` is ``mean(||h||_2 over dim 1) / sqrt(h.shape[1])``.
    ``dh_norm`` is the same statistic for ``h - prev_h``. Both are expected to be O(1).
    Assumes each tensor is 2-D (rows x features) — TODO confirm against the model.
    """
    dhs = [h - prev for h, prev in zip(hs, prev_hs)]
    h_norm = [torch.norm(h, dim=1).mean() / h.shape[1] ** 0.5 for h in hs]
    dh_norm = [torch.norm(dh, dim=1).mean() / dh.shape[1] ** 0.5 for dh in dhs]
    return h_norm, dh_norm


def main(args):
    """Build the base model, convert it with CoLA, train it and periodically evaluate.

    The steps are:
    1. Seed RNGs.
    2. Build ``getattr(nn, args.model)`` and record its parameter/FLOP counts.
    3. Replace layers via ``cola_nn.colafy`` and record the new counts.
    4. Create the checkpoint folder; ``config.txt`` is written only when the folder is new.
    5. Build the loaders and optionally reload weights.
    6. Train over ``range(start_ep, args.epochs)``, evaluating every
       ``args.calculate_stats`` epochs.
    7. Optionally save ``final_checkpoint.pt``.

    Side effect: sets ``args.num_params``.
    """
    set_seed(args.seed)
    # Allow TF32 tensor-core math for float32 matmuls (faster, reduced mantissa precision).
    torch.backends.cuda.matmul.allow_tf32 = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    input_size = (1, 3, args.crop_resolution, args.crop_resolution)
    # NOTE(review): the probe hard-codes 3 input channels while `in_channels` is passed
    # separately — confirm these agree for non-RGB configurations.
    model = getattr(nn, args.model)(dim_in=prod(input_size), dim_out=args.num_classes, depth=args.depth,
                                    width=args.width, patch_size=args.patch_size, in_channels=args.in_channels)
    # Place model and probe explicitly on `device`. Previously the probe was hard-coded to
    # `.cuda()` (crashing on CPU-only hosts). The model only reached the GPU as a side effect
    # of torchinfo.summary moving it. The probe is drawn on CPU, then moved, so the CPU RNG
    # stream is the same as before.
    model = model.to(device)
    probe_input = torch.randn(*input_size).to(device)

    base_params, base_flops = _params_and_flops(model, input_size, probe_input)

    # Trainable parameter count of the base (pre-CoLA) model, kept for logging.
    args.num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    cola_nn.colafy(model, struct=args.struct, layers=args.layers, rank_frac=args.rank_frac, kron_mult=args.kron_mult,
                   tt_dim=args.tt_dim, tt_rank=args.tt_rank)
    # colafy presumably swaps layers in place; re-sync any newly created parameters to `device`.
    model = model.to(device)
    print('CoLA model:')
    cola_params, cola_flops = _params_and_flops(model, input_size, probe_input)

    print(f'Base params: {base_params / 1e6:.2f}M | CoLA params: {cola_params / 1e6:.2f}M')
    print(f'Base flops: {base_flops / 1e6:.2f}M | CoLA flops: {cola_flops / 1e6:.2f}M')

    # Unique run identifier derived from the configuration.
    name = config_to_name(args)
    path = os.path.join(args.checkpoint_folder, name)

    # Create the checkpoint folder; the config is only written when the folder is new.
    if not os.path.exists(path):
        os.makedirs(path)
        with open(os.path.join(path, 'config.txt'), 'w') as f:
            json.dump(args.__dict__, f, indent=2)

    # Each forward/backward pass sees batch_size // accum_steps samples.
    local_batch_size = args.batch_size // args.accum_steps

    train_loader = get_loader(args.dataset, bs=local_batch_size, mode="train", augment=args.augment, dev=device,
                              num_samples=args.n_train, mixup=args.mixup, data_path=args.data_path,
                              data_resolution=args.resolution, crop_resolution=args.crop_resolution)

    test_loader = get_loader(args.dataset, bs=local_batch_size, mode="test", augment=False, dev=device,
                             data_path=args.data_path, data_resolution=args.resolution,
                             crop_resolution=args.crop_resolution)

    start_ep = 1
    if args.reload:
        # NOTE(review): the checkpoint file name and resume epoch (350) are hard-coded
        # placeholders, and optimizer/scheduler state is not restored.
        try:
            params = torch.load(os.path.join(path, "name_of_checkpoint"))
        except FileNotFoundError:
            print("No pretrained model found, training from scratch")
        else:
            model.load_state_dict(params)
            start_ep = 350

    opt = get_optimizer(args.optimizer)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = get_scheduler(opt, args.scheduler, **args.__dict__)

    loss_fn = CrossEntropyLoss(label_smoothing=args.smooth)

    if args.wandb:
        info = {
            'base_params': base_params,
            'base_flops': base_flops,
            'cola_params': cola_params,
            'cola_flops': cola_flops
        }
        # Build a merged copy instead of mutating args.__dict__ in place.
        config = {**vars(args), **info}
        wandb.init(
            project=args.wandb_project,
            name=name,
            config=config,
            tags=["pretrain", args.dataset],
        )

    compute_per_epoch = get_compute(model, args.n_train, args.crop_resolution)

    prev_hs = None
    # NOTE(review): with start_ep == 1 this runs args.epochs - 1 epochs; confirm against the
    # scheduler's configured length before changing the range.
    for ep in (pb := tqdm(range(start_ep, args.epochs))):
        calc_stats = (ep + 1) % args.calculate_stats == 0
        current_compute = compute_per_epoch * ep

        train_acc, train_top5, train_loss, train_time = train(model, opt, scheduler, loss_fn, ep, train_loader, args)

        if not calc_stats:
            continue

        # Reset the per-layer feature buffers so they hold only this evaluation's activations.
        # NOTE(review): assumes the model appends hidden states to `model.hs` in forward; if it
        # also does so during training, the lists grow until this reset — confirm.
        model.hs = [[] for _ in range(len(model.hs))]
        test_acc, test_top5, test_loss, test_time = test(model, test_loader, loss_fn, args)
        hs = [torch.cat(h, dim=0) for h in model.hs]  # one tensor per layer
        if prev_hs is None:
            prev_hs = hs
        h_norm, dh_norm = _hidden_state_norms(hs, prev_hs)
        prev_hs = hs

        if args.wandb:
            logs = {
                "epoch": ep,
                "train_acc": train_acc,
                "test_acc": test_acc,
                "test_loss": test_loss,
                "current_compute": current_compute,
                "Inference time": test_time,
            }
            for i, (hn, dhn) in enumerate(zip(h_norm, dh_norm)):
                logs[f'h_{i}'] = hn.item()
                logs[f'dh_{i}'] = dhn.item()
            wandb.log(logs)
        pb.set_description(f"Epoch {ep}, Train Acc: {train_acc:.2f}, Test Acc: {test_acc:.2f}")

    if args.save:
        torch.save(
            model.state_dict(),
            os.path.join(path, "final_checkpoint.pt"),
        )


if __name__ == "__main__":
    # Parse CLI flags, then derive dataset-dependent defaults before training.
    parser = get_training_parser()
    args = parser.parse_args()

    args.num_classes = data_stats.CLASS_DICT[args.dataset]
    # Fall back to the full dataset size and the data resolution when unset.
    args.n_train = data_stats.SAMPLE_DICT[args.dataset] if args.n_train is None else args.n_train
    args.crop_resolution = args.resolution if args.crop_resolution is None else args.crop_resolution

    main(args)
