from torchvision.transforms.functional import pil_to_tensor
from torchvision.datasets import MNIST, FashionMNIST
from torch.utils.data import DataLoader
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import eval_model_classification, create_generator_from_config
from MLP import MLP
import numpy as np
import tqdm
from pytorch_module_generator_wrapper import wrap_torch_model_lora, wrap_torch_model, get_n_sections_for_model_lora, get_n_sections_for_model
import os
import argparse
from cifar10 import load_cifar10
from imagenet_100 import load_imagenet_100
from cifar100 import load_cifar100
from models import resnet18, resnet56, resnet20
from timm.models.vision_transformer import VisionTransformer
from functools import reduce, partial
from timm.data import Mixup


# Fallback compute device; the ``__main__`` block at the bottom of this file rebinds it from ``--cuda``.
device = torch.device('cuda', 2)

# Default MLP layer sizes: 784 flattened MNIST pixels -> two hidden layers -> 10 outputs.
# NOTE(review): run_trial() does not use n_hidden/arch/bias (it sizes its MLP from ``mlp_hidden_dim``).
n_hidden = 512
arch = [784] + [n_hidden] * 2 + [10]
bias = True


def transform_sample(x):
    """Flatten an image tensor into a 1-D float32 vector (e.g. a 1x28x28 MNIST image -> 784 values)."""
    # Dropping size-1 dims first is unnecessary: flattening already yields every element in row-major order.
    flat = torch.flatten(x)
    return flat.to(dtype=torch.float32)


def load_mnist():
    """Return the MNIST ``(train, test)`` datasets stored under ``./data``.

    Each sample is converted to a tensor by ``pil_to_tensor`` and then flattened to a float vector by
    ``transform_sample``. Only the training split is built with ``download=True``; the test split is built
    second with ``download=False`` and therefore expects the files to already be on disk.
    """
    to_flat_tensor = torchvision.transforms.Compose([pil_to_tensor, transform_sample])

    # Constructed first: this is the only call permitted to download the raw files.
    train_split = MNIST(root="./data", train=True, download=True, transform=to_flat_tensor)
    test_split = MNIST(root="./data", train=False, download=False, transform=to_flat_tensor)
    return train_split, test_split


def convert_state_dict_to_old(state_dict):
    """Move the ``0.basis.2`` weight/bias entries of ``state_dict`` to the ``0.basis.1`` keys.

    The dict is modified in place and also returned. Both source entries are read before anything is
    changed, so a missing key raises ``KeyError`` and leaves ``state_dict`` untouched. Any existing
    ``0.basis.1`` entries are overwritten.
    """
    # NOTE(review): the name suggests this adapts checkpoints to an older module layout — confirm.
    src_keys = ('0.basis.2.weight', '0.basis.2.bias')
    dst_keys = ('0.basis.1.weight', '0.basis.1.bias')

    moved = [state_dict[key] for key in src_keys]
    for key in src_keys:
        del state_dict[key]
    for key, value in zip(dst_keys, moved):
        state_dict[key] = value
    return state_dict

def load_dataset(dataset, model_type, **kwargs):
    """Load the train/test splits for ``dataset`` together with its class count.

    Args:
        dataset: One of ``'in100'``, ``'cifar10'``, ``'cifar100'``, ``'mnist'``.
        model_type: Model name. For ``'in100'``, any model whose name contains ``'vit'`` is loaded with
            ``simple_augmentation=False`` and every other model with ``simple_augmentation=True``.
        **kwargs: ``data_path`` is required for ``'in100'`` and ignored otherwise.

    Returns:
        ``(splits, n_classes)``, where ``splits`` is whatever the dataset loader returns (callers unpack it
        as ``(train_ds, test_ds)``).

    Raises:
        ValueError: if ``dataset`` is not recognised.
    """
    if dataset == 'in100':
        return load_imagenet_100(data_path=kwargs['data_path'], simple_augmentation=('vit' not in model_type)), 100
    if dataset == 'cifar10':
        return load_cifar10(simple_augmentation=True), 10
    if dataset == 'cifar100':
        return load_cifar100(simple_augmentation=True), 100
    if dataset == 'mnist':
        return load_mnist(), 10
    # Raising a bare string is itself a TypeError in Python 3; raise a real exception naming the bad value.
    raise ValueError(f"Invalid dataset: {dataset!r}")


def _norm_modules(model, include_layer_scale=False):
    """Return the submodules of ``model`` whose class name marks them as LayerNorm/BatchNorm (optionally LayerScale).

    run_trial() excludes these modules from weight generation and optimises their parameters directly.
    """
    kinds = ('LayerNorm', 'BatchNorm') + (('LayerScale',) if include_layer_scale else ())
    return [m for m in model.modules() if any(kind in type(m).__name__ for kind in kinds)]


def _build_model(model_type, dataset, n_classes, mlp_hidden_dim):
    """Instantiate ``model_type`` on the module-level ``device`` and split off its classification head.

    Returns:
        ``(model, classifier, exclude_modules, exclude_module_params)``. ``model`` has its head replaced by
        ``nn.Identity`` so that ``classifier(model(x))`` runs the full network. ``exclude_modules`` are passed
        to the weight wrappers as modules to skip, and ``exclude_module_params`` are the tensors the optimiser
        trains directly.

    Raises:
        ValueError: if ``model_type`` is not recognised.
    """
    if model_type in ('resnet20', 'resnet56'):
        constructor = resnet20 if model_type == 'resnet20' else resnet56
        model = constructor(num_classes=n_classes).to(device)
        exclude_modules = _norm_modules(model)
        classifier = model.linear
        model.linear = nn.Identity()
    elif model_type == 'resnet18':
        model = resnet18(num_classes=n_classes).to(device)
        exclude_modules = _norm_modules(model)
        classifier = model.fc
        model.fc = nn.Identity()
    elif model_type in ('vit_ti', 'vit_s'):
        # (embed_dim, num_heads) per ViT size; every other hyper-parameter is shared.
        embed_dim, num_heads = {'vit_ti': (192, 3), 'vit_s': (384, 6)}[model_type]
        model = VisionTransformer(
            num_classes=n_classes, patch_size=16, embed_dim=embed_dim, depth=12, num_heads=num_heads, mlp_ratio=4,
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)).to(device)
        exclude_modules = _norm_modules(model, include_layer_scale=True)
        classifier = model.head
        model.head = nn.Identity()
    elif model_type == 'mlp':
        assert dataset == 'mnist'
        model = MLP(architecture=[784, mlp_hidden_dim, mlp_hidden_dim, 10], final_activation=nn.Identity()).to(device)
        exclude_modules = []
        classifier = model.regressor[0]
        model.regressor[0] = nn.Identity()
    else:
        raise ValueError(f"Invalid model type: {model_type!r}")

    exclude_module_params = [p for m in exclude_modules for p in m.parameters()]
    if model_type in ('vit_ti', 'vit_s'):
        # The class token and positional embedding belong to no excluded module; add them so they are trained directly.
        exclude_module_params += [model.cls_token, model.pos_embed]
    return model, classifier, exclude_modules, exclude_module_params


def run_trial(model_type, gen_config, dataset, width, sched_type, train_amplitude, n_sum_enc, n_sum_class, data_path,
              use_lora, lora_rank, epochs, normalize, lr, batch_size, add_init, use_mixup, mlp_hidden_dim):
    """Train one model whose weights are produced by a generator, then evaluate it.

    The trainable state is a set of latent codes ``k`` (plus per-code amplitudes ``a`` when
    ``train_amplitude``), together with the parameters of the excluded normalisation modules. Each step feeds
    ``k`` through the generator, scales the outputs by ``a``, and sums ``n_sum_enc`` / ``n_sum_class`` vectors
    per section of the encoder / classifier. ``wrap_model`` then installs the result into the network. Uses the
    module-level ``device``.

    Args:
        model_type: ``'resnet20'``, ``'resnet18'``, ``'resnet56'``, ``'vit_ti'``, ``'vit_s'`` or ``'mlp'``.
        gen_config: Passed to ``create_generator_from_config``.
        dataset: Passed to ``load_dataset``.
        width: Latent codes are initialised uniformly in ``[-(width // 2), width - width // 2)``.
        sched_type: ``'plateau'``, ``'step'`` or ``'cosine'``.
        train_amplitude: If True, the amplitudes ``a`` are learned; otherwise they are fixed at 1.
        n_sum_enc, n_sum_class: Number of generated vectors summed per encoder / classifier section.
        data_path: Dataset root, used for ImageNet-100.
        use_lora, lora_rank: Use the LoRA wrappers with the given rank.
        epochs, lr, batch_size: Training-loop settings.
        normalize: L2-normalise each generated vector before scaling it by ``a``.
        add_init: Forwarded to ``wrap_model`` as ``add_init_weights``.
        use_mixup: Apply timm Mixup/CutMix to training batches.
        mlp_hidden_dim: Hidden width for the ``'mlp'`` model.

    Returns:
        ``(train_acc, test_acc, losses_per_epoch, n_params)``. ``losses_per_epoch`` holds the mean batch loss of
        each epoch; ``n_params`` counts the entries of ``k`` (plus ``a`` when it is trained).

    Raises:
        ValueError: if ``width`` is None or ``model_type`` / ``sched_type`` is not recognised.
    """
    if width is None:
        raise ValueError("width must be set (e.g. via --width)")

    if use_lora:
        lora_kwargs = {'rank': lora_rank}
        # Every ViT variant ('vit_ti', 'vit_s') uses the large-LoRA setting; same 'vit' test as load_dataset().
        if 'vit' in model_type:
            lora_kwargs['use_large_lora_cnn'] = True
        get_n_sections = partial(get_n_sections_for_model_lora, **lora_kwargs)
        wrap_model = partial(wrap_torch_model_lora, **lora_kwargs)
    else:
        get_n_sections = get_n_sections_for_model
        wrap_model = wrap_torch_model

    (train_ds, test_ds), n_classes = load_dataset(dataset=dataset, model_type=model_type, data_path=data_path)
    model, classifier, exclude_modules, exclude_module_params = _build_model(
        model_type, dataset, n_classes, mlp_hidden_dim)

    generator = create_generator_from_config(gen_config).to(device)
    gen_in = generator.architecture[0]    # size of one latent code
    gen_size = generator.architecture[-1]  # number of weights emitted per code

    composed_model = nn.Sequential(model, classifier)

    n_sections_enc = get_n_sections(model, gen_size, exclude_modules=exclude_modules)
    n_sections_class = get_n_sections(classifier, gen_size, exclude_modules=exclude_modules)

    # Drop the ragged last batch under mixup (presumably because timm's batch-mode Mixup wants even batches).
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4,
                          drop_last=bool(use_mixup))
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)

    n_k_enc = n_sections_enc * n_sum_enc
    n_k_class = n_sections_class * n_sum_class

    # NOTE(review): for odd ``width`` the init range [-(width // 2), width - width // 2) is not centred on 0.
    k_enc = width * torch.rand((n_sections_enc, n_sum_enc, gen_in)) - (width // 2)
    k_class = width * torch.rand((n_sections_class, n_sum_class, gen_in)) - (width // 2)
    k = torch.cat([k_enc.view(-1, gen_in), k_class.view(-1, gen_in)], dim=0).to(device).requires_grad_(True)

    mixup_fn = None
    if use_mixup:
        # num_classes must match the classifier's output size, or the soft targets and logits disagree.
        mixup_fn = Mixup(
            mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
            prob=1.0, switch_prob=0.5, mode='batch',
            label_smoothing=0.1, num_classes=n_classes)

    if train_amplitude:
        # One amplitude per code, L1-normalised over the n_sum codes summed into the same section.
        a_enc = F.normalize(torch.rand((n_sections_enc, n_sum_enc, 1)), dim=1, p=1.0)
        a_class = F.normalize(torch.rand((n_sections_class, n_sum_class, 1)), dim=1, p=1.0)
        a = torch.cat([a_enc.view(-1, 1), a_class.view(-1, 1)], dim=0).clone().detach().to(device).requires_grad_(True)
        trainable = [k, a]
    else:
        a = torch.ones((k.size(0), 1), requires_grad=False, device=device)
        trainable = [k]
    optimizer = optim.Adam(params=trainable + exclude_module_params, lr=lr)

    if sched_type == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=4)
    elif sched_type == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=50, gamma=0.5)
    elif sched_type == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epochs, eta_min=1e-6)
    else:
        raise ValueError(f"Invalid scheduler type: {sched_type!r}")

    n_params = sum(torch.numel(t) for t in trainable)
    print(f"Trainable params: {n_params}")

    def apply_generated_weights():
        """Generate weights from (k, a) and install them into ``model`` and ``classifier``."""
        generator.zero_grad()
        weights = generator(k.view(-1, gen_in)).view(-1, gen_size)
        if normalize:
            # Scale each generated vector to (approximately) unit L2 norm; 1e-5 guards against division by 0.
            weights = weights / (torch.linalg.norm(weights, dim=1) + 1e-5).unsqueeze(-1)
        weights = weights * a
        # Each section receives the amplitude-weighted sum of its n_sum generated vectors.
        weights_enc = weights[:n_k_enc].view(n_sections_enc, n_sum_enc, -1).sum(dim=1)
        weights_class = weights[n_k_enc:n_k_enc + n_k_class].view(n_sections_class, n_sum_class, -1).sum(dim=1)
        wrap_model(model, weights_enc, 0, exclude_modules=exclude_modules, add_init_weights=add_init)
        wrap_model(classifier, weights_class, 0, exclude_modules=exclude_modules, add_init_weights=add_init)

    losses_per_epoch = []
    for i in range(epochs):
        total_loss = 0.0
        for x, y in tqdm.tqdm(train_dl):
            x = x.to(device)
            y = y.to(device)

            if mixup_fn is not None:
                # Mixes the inputs and replaces y with soft (mixed, label-smoothed) targets.
                x, y = mixup_fn(x, y)

            apply_generated_weights()
            logits = classifier(model(x))
            loss = F.cross_entropy(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.detach().cpu().item()
        print(f"loss epoch {i}: {total_loss}")
        losses_per_epoch.append(total_loss / len(train_dl))
        if (i + 1) % 5 == 0:
            apply_generated_weights()
            test_acc = eval_model_classification(composed_model, test_dl, device)
            print(f"epoch: {i}, test acc: {test_acc}")
        if sched_type == 'plateau':
            scheduler.step(total_loss)
        else:
            scheduler.step()

    apply_generated_weights()
    train_acc = eval_model_classification(composed_model, train_dl, device)
    test_acc = eval_model_classification(composed_model, test_dl, device)
    print(f"Final train acc: {train_acc}, test acc: {test_acc}")

    return train_acc, test_acc, losses_per_epoch, n_params


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen-config', type=str)
    # --model, --width, --dataset, --cuda and --output-path have no usable default (a missing value would only
    # fail later, e.g. on ``cuda:None`` or ``os.path.join(None)``), so argparse requires them up front.
    parser.add_argument('--model', type=str, required=True,
                        choices=['resnet18', 'vit_ti', 'vit_s', 'resnet20', 'resnet56', 'mlp'])
    parser.add_argument('--sched-type', type=str, choices=['plateau', 'step', 'cosine'], default='plateau')
    parser.add_argument('--n-sum-enc', type=int, default=1)
    parser.add_argument('--n-sum-class', type=int, default=1)
    parser.add_argument('--no-amplitude', dest='amplitude', action='store_false')
    parser.add_argument('--n-trials', type=int, default=3)
    parser.add_argument('--width', type=int, required=True)
    parser.add_argument('--dataset', type=str, required=True, choices=['in100', 'cifar10', 'cifar100', 'mnist'])
    parser.add_argument('--data-path', type=str, default='./imagenet-100')
    parser.add_argument('--lora', action='store_true')
    parser.add_argument('--lora-rank', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=400)
    parser.add_argument('--normalize', action='store_true')
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--cuda', type=int, required=True)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--no-add-init', action='store_false', dest='add_init')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--mlp-hidden-dim', type=int, default=256)
    parser.set_defaults(amplitude=True, add_init=True)

    args = parser.parse_args()
    # Rebind the module-level ``device`` that run_trial() reads.
    device = torch.device(f"cuda:{args.cuda}")
    torch.cuda.set_device(f'cuda:{args.cuda}')

    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    # Hyper-parameter tag shared by the result file names written below.
    run_tag = (f"sched-type={args.sched_type}-n-sum-enc={args.n_sum_enc}"
               f"-n-sum-class={args.n_sum_class}-width={args.width}")

    train_accs = []
    test_accs = []
    for i in range(args.n_trials):
        train_acc, test_acc, losses_per_epoch, n_params = run_trial(
            args.model, args.gen_config, args.dataset, args.width, args.sched_type, args.amplitude, args.n_sum_enc,
            args.n_sum_class, args.data_path, args.lora, args.lora_rank, args.epochs, args.normalize, args.lr,
            args.batch_size, args.add_init, args.mixup, args.mlp_hidden_dim)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        # str() kept explicitly: the accuracy type is defined by eval_model_classification and may format differently.
        with open(os.path.join(output_path, f"trial-{i}-accs-{run_tag}.txt"), "w") as f:
            f.write(str(train_acc) + "\n")
            f.write(str(test_acc) + "\n")
        with open(os.path.join(output_path, f"trial-{i}-{run_tag}-loss.txt"), "w") as f:
            f.writelines(f"{loss}\n" for loss in losses_per_epoch)
        if i == 0:
            # The parameter count does not depend on the trial, so it is recorded once.
            with open(os.path.join(output_path, "num_params.txt"), "w") as f:
                f.write(f"{n_params}\n")

    train_arr = np.array(train_accs)
    test_arr = np.array(test_accs)
    with open(os.path.join(output_path, f"Acc-{run_tag}.txt"), "w") as f:
        for stat in (train_arr.mean(), train_arr.std(), test_arr.mean(), test_arr.std()):
            f.write(str(stat) + '\n')
    print(f"average train acc: {train_arr.mean()}, std: {train_arr.std()}")
    print(f"average test acc: {test_arr.mean()}, std: {test_arr.std()}")
