import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import functional as F
from avalanche.evaluation.metrics.accuracy import Accuracy
from tqdm import tqdm
import numpy as np
import random
import timm
from timm.models import create_model
from timm.scheduler.cosine_lr import CosineLRScheduler
from argparse import ArgumentParser
from data import *
import yaml
from pathlib import Path
from superlayer import ModuleInjection, SuperScalableLinear
import ast
from evolution_utils import EvolutionSearcher

def set_seed(seed=0):
    """Seed the NumPy, Python and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed handed unchanged to every generator (default ``0``).
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def set_opetl(model):
    """Replace every non-head ``nn.Linear`` in ``model`` in place.

    Each ``nn.Linear`` found by ``model.named_modules()`` whose attribute name
    (the last component of its dotted path) does not contain ``'head'`` is
    swapped for the result of ``ModuleInjection.make_scalable`` applied to it.

    Args:
        model: Root module to modify in place.
    """
    def _owner_of(path_parts):
        # Walk down from the root: numeric components index into a container,
        # every other component is an attribute lookup.
        node = model
        for part in path_parts:
            node = node[int(part)] if part.isnumeric() else getattr(node, part)
        return node

    # Pass 1: record (parent module, attribute name) for every linear layer.
    # Replacement is deferred so the module tree is not mutated while
    # named_modules() is still being iterated.
    targets = []
    for qualified_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        *parent_path, attr = qualified_name.strip().split('.')
        targets.append((_owner_of(parent_path), attr))

    # Pass 2: swap in the scalable version, leaving classifier heads untouched.
    for owner, attr in targets:
        if 'head' in attr:
            continue
        original = getattr(owner, attr)
        setattr(owner, attr, ModuleInjection.make_scalable(original))

def load(args, vit):
    """Copy selected tensors from a saved checkpoint into ``vit``.

    The checkpoint path is ``args.load_path + args.dataset + '.pt'``. Only
    parameters whose name contains one of ``'A'``, ``'B'``, ``'C'``, ``'D'``,
    ``'E'`` or ``'head'`` are overwritten; all others keep their current data.

    Args:
        args: Object exposing ``load_path`` and ``dataset`` string attributes.
        vit: Module whose matching parameters are overwritten in place.

    Returns:
        The same ``vit`` object that was passed in.

    Raises:
        KeyError: If a matching parameter name is absent from the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.load_path + args.dataset + '.pt')
    markers = ('A', 'B', 'C', 'D', 'E', 'head')

    restored = 0
    for param_name, param in vit.named_parameters():
        # Plain substring match against the full dotted parameter name.
        if not any(marker in param_name for marker in markers):
            continue
        param.data = checkpoint[param_name]
        restored += 1

    print(f'successfully loaded {restored} parameters')
    return vit

def parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``,
    so ``--evaluate False`` would enable the flag. This converter accepts the
    usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``'true'``/``'false'``, ``'1'``/``'0'``, ``'yes'``/``'no'``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The corresponding ``bool``.

    Raises:
        ValueError: If the token is not a recognised spelling; argparse turns
            this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    # ---- command-line interface -------------------------------------------
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--model', type=str, default='vit_base_patch16_224_in21k')
    parser.add_argument('--dataset', type=str, default='cifar')
    parser.add_argument('--save_path', type=str, default='models/temp/')
    parser.add_argument('--load_path', type=str, default='models/temp/')
    parser.add_argument('--max-epochs', type=int, default=20)
    parser.add_argument('--select-num', type=int, default=10)
    parser.add_argument('--population-num', type=int, default=50)
    parser.add_argument('--m_prob', type=float, default=0.2)
    parser.add_argument('--crossover-num', type=int, default=25)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--mutation-num', type=int, default=25)
    parser.add_argument('--param-limits', type=float, default=1.00)
    parser.add_argument('--min-param-limits', type=float, default=0)
    # type=bool would parse "False" as True; use an explicit converter.
    parser.add_argument('--evaluate', type=parse_bool_flag, default=False)
    args = parser.parse_args()

    # ---- run setup ----------------------------------------------------------
    seed = args.seed
    set_seed(seed)
    # NOTE(review): the device is hard-coded to the first CUDA GPU.
    device = torch.device('cuda:0')
    name = args.dataset
    # Extra attribute attached to args; it is passed on to the searcher below.
    args.best_acc = 0
    Path(args.save_path).mkdir(parents=True, exist_ok=True)

    # ---- model and data -----------------------------------------------------
    vit = create_model(args.model, checkpoint_path='./ViT-B_16.npz', drop_path_rate=0.1)
    train_dl, test_dl = get_data(name, evaluate=args.evaluate)
    # Wrap the linear layers before resizing the classifier and loading the
    # checkpoint, so parameter names match what load() looks up.
    set_opetl(vit)
    vit.reset_classifier(get_classes_num(name))
    vit = load(args, vit)

    # Freeze every parameter (the search below only evaluates the model) and
    # count elements so the report reflects the real model size.
    total_param = 0
    for p in vit.parameters():
        p.requires_grad = False
        total_param += p.numel()
    print('total_param', total_param)

    # ---- search space -------------------------------------------------------
    # Candidate options for each of the five keys handed to the searcher.
    choices = dict()
    choices['A'] = ['LoRA_2', 'vector', 'constant', 'none']
    choices['B'] = ['LoRA_2', 'vector', 'constant', 'none']
    choices['C'] = ['LoRA_2', 'vector', 'none']
    choices['D'] = ['constant', 'none', 'vector']
    choices['E'] = ['constant', 'none', 'vector']

    searcher = EvolutionSearcher(args, device, vit, choices, test_dl, args.save_path)
    searcher.search()