import argparse
import itertools
import os
import time
from collections import OrderedDict, defaultdict
from copy import deepcopy

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from evaluator import Evaluator
from torchvision.models.resnet import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights
import torchvision
from tqdm import trange

from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.algorithms.soo.nonconvex.cmaes import CMAES
from pymoo.algorithms.moo.sms import SMSEMOA
from pymoo.operators.crossover.pntx import TwoPointCrossover
from pymoo.operators.mutation.bitflip import BitflipMutation
from pymoo.operators.sampling.rnd import BinaryRandomSampling
from pymoo.optimize import minimize
from pymoo.core.problem import Problem
import numpy as np
#from train_cl import get_info, train, validate, train_moo
import ray
import wandb
from dataset import PickleDataset, HAM10000DatasetBalanced


# Seed PyTorch's global RNG when this module is loaded.
# NOTE(review): train_moo executes as a Ray remote task; confirm whether this
# seed is also applied inside the worker processes.
torch.manual_seed(0)

@ray.remote(num_gpus=0.5)
def train_moo(moo_params, data_flag, num_epochs, batch_size, model_flag, tune_option, test_flag, train_frac, num_workers, lr):
    """Fine-tune a pretrained ResNet on one dataset and evaluate it.

    Runs as a Ray remote task that requests half a GPU.

    Args:
        moo_params: Per-layer search vector forwarded to
            ``get_trainable_params`` (a 0/1 mask for 'moo'/'SMSEMOA', a
            learning rate per layer group for 'lrmoo'); not used by the
            other tune options.
        data_flag: Dataset name. "HAM10000" uses ``HAM10000DatasetBalanced``;
            any other value loads ``cached_data/{data_flag}.pkl``.
        num_epochs: Number of training passes. With 0 the function returns
            ``None`` right after the model has been built.
        batch_size: Batch size of the training loader (validation and test
            loaders always use 512).
        model_flag: 'resnet18' or 'resnet50'.
        tune_option: Fine-tuning strategy understood by
            ``get_trainable_params``.
        test_flag: When true, also evaluate on the held-out test split.
        train_frac: Fraction of the dataset used for training; 20% is used
            for validation and the remainder for testing.
        num_workers: Worker count for all data loaders.
        lr: Base learning rate for Adam (parameter groups that carry their
            own ``lr`` override it).

    Returns:
        ``None`` if ``num_epochs`` is 0; ``(val_metrics, test_metrics)`` if
        ``test_flag`` is true; otherwise ``val_metrics``. Each metrics value
        is the ``[auc, acc]`` list produced by ``test``.

    Raises:
        NotImplementedError: If ``model_flag`` is not a supported backbone.
        ValueError: If the split fractions exceed the dataset size or
            ``tune_option`` yields no trainable parameters.
    """
    start_time = time.time()
    device = torch.device('cuda')

    print('==> Preparing data...')
    if data_flag == "HAM10000":
        dataset = HAM10000DatasetBalanced()
    else:
        dataset = PickleDataset(f"cached_data/{data_flag}.pkl")
    n_classes = len(dataset.classes)
    print(f"Dataset has {len(dataset)} images and {n_classes} classes")

    train_len = int(len(dataset) * train_frac)
    val_len = int(len(dataset) * 0.2)
    test_len = len(dataset) - train_len - val_len
    if test_len < 0:
        raise ValueError(
            f"train_frac={train_frac} plus the 20% validation split exceeds the dataset size"
        )
    # Fixed generator seed keeps the three splits identical across runs.
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(41),
    )

    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   pin_memory=True)
    val_loader = data.DataLoader(dataset=val_dataset,
                                 batch_size=512,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=True)
    test_loader = data.DataLoader(dataset=test_dataset,
                                  batch_size=512,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  pin_memory=True)

    print('==> Building and training model...')
    if model_flag == 'resnet18':
        model = resnet18(weights=ResNet18_Weights.DEFAULT)
    elif model_flag == 'resnet50':
        model = resnet50(weights=ResNet50_Weights.DEFAULT)
    else:
        raise NotImplementedError
    # Replace the classification head so it matches this dataset's classes.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    model = model.to(device)
    model.train()

    train_evaluator = Evaluator(data_flag, 'train')
    val_evaluator = Evaluator(data_flag, 'val')
    test_evaluator = Evaluator(data_flag, 'test')

    criterion = nn.CrossEntropyLoss()

    if num_epochs == 0:
        return

    # Resolve the trainable parameters for every tune option. Previously the
    # variable was only bound for a fixed list of options, so e.g. 'full'
    # crashed with UnboundLocalError when the optimizer was created.
    trainable_params = get_trainable_params(model, tune_option, moo_params, loader=train_loader)
    if trainable_params is None:
        raise ValueError(f"Unsupported tune_option: {tune_option!r}")

    optimizer = torch.optim.Adam(trainable_params, lr=lr, weight_decay=1e-4)

    for _ in range(num_epochs):
        train(model, train_loader, criterion, optimizer, device)

    val_metrics = test(model, val_evaluator, val_loader, criterion, device)
    test_metrics = None
    if test_flag:
        test_metrics = test(model, test_evaluator, test_loader, criterion, device)

    print(f"Time taken: {time.time() - start_time}")
    if test_flag:
        return val_metrics, test_metrics
    return val_metrics



def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts the model in training mode, performs one optimizer step per batch
    and returns the mean of the per-batch loss values as a Python float.
    Targets are cast to ``long`` before being passed to ``criterion``.
    """
    model.train()

    batch_losses = []
    for inputs, targets in train_loader:
        optimizer.zero_grad()

        logits = model(inputs.to(device))
        loss = criterion(logits, targets.long().to(device))
        batch_losses.append(loss.item())

        loss.backward()
        optimizer.step()

    return sum(batch_losses) / len(batch_losses)


def test(model, evaluator: Evaluator, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and return ``[auc, acc]``.

    The model is switched to eval mode and run without gradient tracking.
    Softmax scores and targets from all batches are collected, moved to
    NumPy and handed to ``evaluator.evaluate(targets, scores)``.

    Args:
        model: Network producing one row of class logits per sample.
        evaluator: Object whose ``evaluate(targets, scores)`` returns the
            pair ``(auc, acc)``.
        data_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Unused. The loss was previously computed and discarded;
            the parameter is kept so existing call sites keep working.
        device: Device the inputs and targets are moved to.

    Returns:
        ``[auc, acc]`` as returned by the evaluator.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    # Collect per-batch results and concatenate once at the end; growing a
    # tensor with torch.cat inside the loop re-copies everything each batch.
    score_batches = []
    target_batches = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            logits = model(inputs.to(device))
            score_batches.append(F.softmax(logits, dim=1))
            # Same conversion chain as before (long -> float), stored as an
            # (N, 1) column.
            target_batches.append(targets.long().to(device).float().reshape(-1, 1))

    if not score_batches:
        raise ValueError("data_loader produced no batches to evaluate")

    y_score = torch.cat(score_batches, 0).cpu().numpy()
    targets_all = torch.cat(target_batches, 0).cpu().numpy()
    auc, acc = evaluator.evaluate(targets_all, y_score)

    return [auc, acc]

def get_trainable_params(model, tune_option, moo_params, loader=None):
    """Select what the optimizer should train for the given tuning strategy.

    Args:
        model: Network whose parameters are selected (and, for most options,
            frozen or unfrozen in place through ``requires_grad``).
        tune_option: One of
            'full'        - every parameter (``model.parameters()``);
            'layer_first' - parameters whose name contains 'layer1' or 'fc';
            'layer_last'  - parameters whose name contains 'fc';
            'moo' / 'SMSEMOA' - layer groups chosen by the mask ``moo_params``;
            'lrmoo'       - parameter groups with per-layer learning rates
                            taken from ``moo_params``;
            'auto_rgn'    - parameter groups with learning rates derived from
                            gradient norms measured on ``loader``.
        moo_params: Search vector, only used by 'moo', 'SMSEMOA' and 'lrmoo'.
        loader: Data loader, only used by 'auto_rgn'.

    Returns:
        An iterable of parameters or a list of optimizer parameter-group
        dicts, suitable as the first argument of a torch optimizer.

    Raises:
        ValueError: If ``tune_option`` is not one of the values above
            (previously this silently returned ``None``).
    """
    if tune_option == 'full':
        return model.parameters()

    if tune_option in ('layer_first', 'layer_last'):
        # Substrings that mark a parameter as trainable; everything else is
        # frozen.
        keep_tags = ('layer1', 'fc') if tune_option == 'layer_first' else ('fc',)
        params_to_update = []
        for name, param in model.named_parameters():
            param.requires_grad = any(tag in name for tag in keep_tags)
            if param.requires_grad:
                params_to_update.append(param)
        return params_to_update

    if tune_option in ('moo', 'SMSEMOA'):
        return get_trainable_weights_names(model, moo_params)

    if tune_option == 'lrmoo':
        return get_lr_weights_names(model, moo_params, default_lr=1e-4)

    if tune_option == "auto_rgn":
        return get_lr_weights_auto(model, loader, 1e-3)

    raise ValueError(f"Unsupported tune_option: {tune_option!r}")

def get_trainable_weights_names(model, layer_requires_grad):
    """Unfreeze the layer groups selected by a 0/1 mask and freeze the rest.

    A layer group is built for every currently trainable parameter whose
    name contains 'conv': the parameter itself plus the batch-norm weight
    and bias names obtained by replacing 'conv' with 'bn'. A final group
    holds 'fc.weight' and 'fc.bias'. Group ``i`` is trained when
    ``layer_requires_grad[i] == 1``. Parameters whose name contains
    'downsample' are always trained.

    Returns:
        The list of parameters left with ``requires_grad=True``, in
        ``model.named_parameters()`` order.
    """
    layer_groups = []
    for name, param in model.named_parameters():
        if not (param.requires_grad and 'conv' in name):
            continue
        bn_weight_name = name.replace('conv', 'bn')
        bn_bias_name = bn_weight_name.replace('weight', 'bias')
        layer_groups.append([name, bn_weight_name, bn_bias_name])
    layer_groups.append(['fc.weight', 'fc.bias'])

    # Names of all parameters belonging to a selected group.
    selected_names = set()
    for idx, group in enumerate(layer_groups):
        if layer_requires_grad[idx] == 1:
            selected_names.update(group)

    optim_params = []
    for name, param in model.named_parameters():
        trainable = name in selected_names or 'downsample' in name
        param.requires_grad = trainable
        if trainable:
            optim_params.append(param)
    return optim_params

def get_lr_weights_names(model, lr_per_layer, default_lr):
    """Build optimizer parameter groups with one learning rate per layer group.

    Layer groups are formed as follows: every currently trainable parameter
    whose name contains 'conv' is grouped with the batch-norm weight and
    bias whose names result from replacing 'conv' with 'bn'; the last group
    is 'fc.weight' and 'fc.bias'. Group ``i`` gets ``lr_per_layer[i]``.
    All parameters whose name contains 'downsample' form one extra group
    that uses ``default_lr``.

    Returns:
        A list of ``{'params': [...], 'lr': ...}`` dicts.
    """
    named_params = dict(model.named_parameters())

    layer_groups = []
    for name, param in named_params.items():
        if param.requires_grad and 'conv' in name:
            bn_weight_name = name.replace('conv', 'bn')
            layer_groups.append(
                [name, bn_weight_name, bn_weight_name.replace('weight', 'bias')]
            )
    layer_groups.append(['fc.weight', 'fc.bias'])

    param_groups = [
        {'params': [named_params[member] for member in group], 'lr': lr_per_layer[idx]}
        for idx, group in enumerate(layer_groups)
    ]

    # Downsample parameters are not part of the search; they use the default.
    downsample_params = [p for name, p in named_params.items() if 'downsample' in name]
    param_groups.append({'params': downsample_params, 'lr': default_lr})
    return param_groups

def get_lr_weights_by_norm(model, loader):
    """Measure the relative gradient norm of each non-batch-norm parameter.

    For up to the first five batches of ``loader`` the cross-entropy loss is
    differentiated with respect to all model parameters. For every parameter
    whose name does not contain "bn", the ratio ``||grad|| / ||param||`` is
    recorded per batch and averaged over the batches.

    Batches are moved to the device of the model's first parameter, and
    targets are cast to ``long`` as in ``train``/``test``.

    Args:
        model: Network to probe; every parameter must require gradients.
        loader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        A ``defaultdict(float)`` mapping parameter name to its averaged
        ratio, in ``model.named_parameters()`` order. Empty if ``loader``
        yields no batches.
    """
    layer_names = {n for n, _ in model.named_parameters() if "bn" not in n}
    device = next(model.parameters()).device

    ratios_per_batch = defaultdict(list)
    for x, y in itertools.islice(loader, 5):
        x, y = x.to(device), y.long().to(device)
        loss_xent = F.cross_entropy(model(x), y)
        # The graph is used once per batch, so it need not be retained.
        grads = torch.autograd.grad(outputs=loss_xent, inputs=list(model.parameters()))
        for (name, param), grad in zip(model.named_parameters(), grads):
            if name in layer_names:
                ratios_per_batch[name].append(
                    torch.norm(grad).item() / torch.norm(param).item()
                )

    average_metrics = defaultdict(float)
    for name, ratios in ratios_per_batch.items():
        average_metrics[name] = np.array(ratios).mean(0)
    return average_metrics


def get_lr_weights_auto(model, loader, lr):
    """Build parameter groups whose learning rates follow relative gradient norms.

    The per-parameter ratios from ``get_lr_weights_by_norm`` are divided by
    their maximum, so the parameter with the largest ratio trains with
    ``lr`` and all others with a proportionally smaller rate. Parameters
    whose name contains "bn" receive no group.

    Args:
        model: Network to build parameter groups for.
        loader: Iterable of ``(inputs, targets)`` batches used for probing.
        lr: Learning rate assigned to the parameter with the largest ratio.

    Returns:
        A list of ``{"params": parameter, "lr": value}`` dicts.

    Raises:
        ValueError: If no gradient statistics could be collected (for
            example because ``loader`` is empty).
    """
    weights = get_lr_weights_by_norm(model, loader)
    if not weights:
        raise ValueError("no gradient statistics collected; is the loader empty?")

    max_weight = max(weights.values())
    named_params = dict(model.named_parameters())
    return [
        {"params": named_params[name], "lr": (weight / max_weight) * lr}
        for name, weight in weights.items()
    ]

#problem = create_random_knapsack_problem(30)

# pymoo problem definitions: every candidate solution is scored by one train_moo run.

class MyProblem(Problem):
    """Layer-selection problem: each candidate is a 0/1 mask over layer groups.

    ``args`` is the keyword-argument dict forwarded to ``train_moo`` for
    every candidate; the candidate itself is passed as ``moo_params``.
    """

    def __init__(self, args, n_var=18, n_obj=2):
        super().__init__(n_var=n_var, n_obj=n_obj)
        self.args = args

    def _evaluate(self, x, out, *args, **kwargs):
        """Train one model per row of ``x`` in parallel and store the objectives.

        ``*args``/``**kwargs`` are accepted (and ignored) because pymoo's
        documented ``_evaluate`` signature includes them and the framework
        may forward extra arguments to this hook.
        """
        masks = x.astype(np.uint8)
        results = ray.get([train_moo.remote(mask, **self.args) for mask in masks])
        # pymoo minimizes, so the returned metrics are negated to maximize them.
        out["F"] = -np.array(results)


class LR_Problem(Problem):
    """Learning-rate problem: each candidate holds one rate per layer group.

    Every variable is bounded to ``[1e-7, 1e-3]``. ``args`` is the
    keyword-argument dict forwarded to ``train_moo`` for every candidate;
    the candidate itself is passed as ``moo_params``.
    """

    def __init__(self, args, n_var=18, n_obj=2):
        super().__init__(n_var=n_var, n_obj=n_obj, xu=1e-3, xl=1e-7)
        self.args = args

    def _evaluate(self, x, out, *args, **kwargs):
        """Train one model per row of ``x`` in parallel and store the objectives.

        ``*args``/``**kwargs`` are accepted (and ignored) because pymoo's
        documented ``_evaluate`` signature includes them and the framework
        may forward extra arguments to this hook.
        """
        results = ray.get([train_moo.remote(rates, **self.args) for rates in x])
        # pymoo minimizes, so the returned metrics are negated to maximize them.
        out["F"] = -np.array(results)


def run_moo(args):
    """Search binary layer masks with NSGA-II or SMS-EMOA and log the result.

    After the search, every solution in ``res.X`` is retrained with
    ``test_flag=True`` and the test metrics are logged through ``wandb_log``.

    Args:
        args: Keyword arguments for ``train_moo`` (without ``moo_params``).
            ``args['test_flag']`` is set to ``True`` as a side effect.

    Raises:
        ValueError: If ``args['tune_option']`` is neither 'moo' nor 'SMSEMOA'.
    """
    algorithm_classes = {'moo': NSGA2, 'SMSEMOA': SMSEMOA}
    tune_option = args['tune_option']
    if tune_option not in algorithm_classes:
        raise ValueError(f"run_moo does not support tune_option {tune_option!r}")

    # NOTE(review): 18/50 presumably equal the number of layer groups built by
    # get_trainable_weights_names for each backbone - confirm.
    n_var = 18 if args["model_flag"] == "resnet18" else 50

    # During the search train_moo must return validation metrics only. If the
    # caller already set test_flag, it would return a (val, test) pair and the
    # objective array would have the wrong shape.
    search_args = dict(args, test_flag=False)
    problem = MyProblem(search_args, n_var=n_var, n_obj=2)
    algorithm = algorithm_classes[tune_option](
        pop_size=10,
        sampling=BinaryRandomSampling(),
        crossover=TwoPointCrossover(),
        mutation=BitflipMutation(),
        eliminate_duplicates=True)

    res = minimize(problem,
                   algorithm,
                   ('n_gen', 10),
                   verbose=False)

    # Retrain every returned solution, this time also evaluating on the test split.
    args['test_flag'] = True
    results = ray.get([train_moo.remote(solution, **args) for solution in res.X])
    pareto_test_metrics = np.array([result[1] for result in results])
    wandb_log(args, res, pareto_test_metrics)


def run_lrmoo(args):
    """Search per-layer learning rates with NSGA-II and log the result.

    After the search, every solution in ``res.X`` is retrained with
    ``test_flag=True`` and the test metrics are logged through ``wandb_log``.

    Args:
        args: Keyword arguments for ``train_moo`` (without ``moo_params``).
            ``args['test_flag']`` is set to ``True`` as a side effect.
    """
    # NOTE(review): 18/50 presumably equal the number of layer groups built by
    # get_lr_weights_names for each backbone - confirm.
    n_var = 18 if args["model_flag"] == "resnet18" else 50

    # During the search train_moo must return validation metrics only. If the
    # caller already set test_flag, it would return a (val, test) pair and the
    # objective array would have the wrong shape.
    search_args = dict(args, test_flag=False)
    problem = LR_Problem(search_args, n_var=n_var, n_obj=2)
    algorithm = NSGA2(pop_size=8, eliminate_duplicates=True)
    res = minimize(problem,
                   algorithm,
                   ('n_gen', 12),
                   verbose=False)

    # Retrain every returned solution, this time also evaluating on the test split.
    args['test_flag'] = True
    results = ray.get([train_moo.remote(solution, **args) for solution in res.X])
    pareto_test_metrics = np.array([result[1] for result in results])
    wandb_log(args, res, pareto_test_metrics)


def run_others(args):
    """Run a non-search baseline: sweep the learning rate, then test the best.

    Seven learning rates are trained in parallel. The one whose first
    validation metric (AUC, per the ``[auc, acc]`` order returned by
    ``test``) is highest is retrained with ``test_flag=True``, and its
    validation and test metrics are logged to Weights & Biases.

    Args:
        args: Keyword arguments for ``train_moo`` (without ``moo_params``).
            ``args['lr']`` and ``args['test_flag']`` are overwritten with the
            selected learning rate and ``True``.
    """
    lr_options = [1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5]

    # The sweep must see validation metrics only. With test_flag already set,
    # train_moo returns (val, test) pairs and the argmax below would pick an
    # index into the flattened pairs instead of a learning rate.
    sweep = [
        train_moo.remote(None, **dict(args, lr=lr, test_flag=False))
        for lr in lr_options
    ]
    results = np.array(ray.get(sweep))
    best_lr = lr_options[int(np.argmax(results[:, 0]))]

    args["lr"] = best_lr
    args["test_flag"] = True
    result = ray.get(train_moo.remote(None, **args))

    wandb.init(project="moo_class", tags=[args["data_flag"], args["tune_option"]])
    wandb.config.update(args)
    wandb.log({"test_auc": result[1][0], "test_acc": result[1][1],
               "val_auc": result[0][0], "val_acc": result[0][1]})


def wandb_log(args, res, pareto_test_metrics):
    """Log summary statistics of the retrained solutions to Weights & Biases.

    Column 0 of ``pareto_test_metrics`` is reported as AUC and column 1 as
    accuracy (max, mean and standard deviation of each), together with the
    solutions ``res.X`` (as a string) and objective values ``res.F``.
    """
    run_tags = [args["data_flag"], args["tune_option"]]
    wandb.init(project="moo_class", tags=run_tags)
    wandb.config.update(args)

    auc_values = pareto_test_metrics[:, 0]
    acc_values = pareto_test_metrics[:, 1]
    summary = {
        "test_auc": np.max(auc_values),
        "test_acc": np.max(acc_values),
        "mean_auc": np.mean(auc_values),
        "mean_acc": np.mean(acc_values),
        "std_auc": np.std(auc_values),
        "std_acc": np.std(acc_values),
        "X": str(res.X),
        "F": res.F,
    }
    wandb.log(summary)

if __name__ == '__main__':
    # Driver for each supported --tune_option. Values outside this table are
    # rejected by argparse instead of silently doing nothing.
    runners = {
        "moo": run_moo,
        "SMSEMOA": run_moo,
        "lrmoo": run_lrmoo,
        "layer_first": run_others,
        "layer_last": run_others,
        "auto_rgn": run_others,
    }

    parser = argparse.ArgumentParser(
        description='RUN Baseline model of MedMNIST2D')

    parser.add_argument('--data_flag',
                        default='SRSMAS',
                        type=str)
    parser.add_argument('--num_epochs',
                        default=50,
                        help='num of epochs of training, the script would only test model if set num_epochs to 0',
                        type=int)
    parser.add_argument('--batch_size',
                        default=64,
                        type=int)
    parser.add_argument('--num_workers',
                        default=0,
                        type=int)
    parser.add_argument('--model_flag',
                        default='resnet18',
                        help='choose backbone from resnet18, resnet50',
                        type=str)
    parser.add_argument('--tune_option',
                        default='lrmoo',
                        choices=list(runners),
                        help='tune option, choose from ' + ', '.join(runners))
    parser.add_argument('--test_flag',
                        action="store_true")
    parser.add_argument('--train_frac',
                        default=0.4,
                        type=float)
    parser.add_argument("--lr",
                        default=1e-4,
                        type=float)

    args = vars(parser.parse_args())

    # Start Ray only after the arguments are known to be valid, so --help and
    # usage errors do not spin up a cluster first.
    ray.init()
    runners[args["tune_option"]](args)