import argparse
import itertools
import os
import time
from collections import OrderedDict, defaultdict
from copy import deepcopy

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from evaluator import Evaluator
from torchvision.models.resnet import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights
import torchvision
from tqdm import trange

from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.algorithms.soo.nonconvex.cmaes import CMAES
from pymoo.algorithms.moo.sms import SMSEMOA
from pymoo.operators.crossover.pntx import TwoPointCrossover
from pymoo.operators.mutation.bitflip import BitflipMutation
from pymoo.operators.sampling.rnd import BinaryRandomSampling
from pymoo.optimize import minimize
from pymoo.core.problem import Problem
import numpy as np
import ray
import wandb
from dataset import PickleDataset, HAM10000DatasetBalanced


# Seed torch's global RNG once, when this module is executed.
# NOTE(review): training runs inside Ray tasks (separate worker processes);
# confirm whether this driver-side seed is actually applied there.
torch.manual_seed(0)

@ray.remote(num_gpus=0.5)
def train_moo_with_optimizer(optimizer_params, data_flag, num_epochs, batch_size, model_flag, tune_option, test_flag, train_frac, num_workers, optimizer_type):
    """Fine-tune a ResNet with one hyperparameter vector and return its metrics.

    Declared as a Ray remote function that reserves half a GPU per call.

    Args:
        optimizer_params: Flat vector of per-layer learning rates followed by
            the optimizer hyperparameters; it is forwarded unchanged to
            ``get_lr_weights_with_optimizer``.
        data_flag: Dataset name; data is read from
            ``cached_data/{data_flag}.pkl``.
        num_epochs: Number of training epochs. ``0`` returns ``None`` after
            the data and model have been built, without training.
        batch_size: Batch size of the training loader (evaluation uses 512).
        model_flag: ``'resnet18'`` or ``'resnet50'``.
        tune_option: Not used by this function; accepted so callers can pass
            their whole argument dict.
        test_flag: When true, the test split is evaluated as well.
        train_frac: Fraction of the dataset used for training. 20% is always
            used for validation and the remainder for testing.
        num_workers: Worker count for every ``DataLoader``.
        optimizer_type: ``'adamw'`` or ``'rmsprop'``.

    Returns:
        ``val_metrics`` (the ``[auc, acc]`` list produced by ``test``) when
        ``test_flag`` is false, otherwise ``(val_metrics, test_metrics)``.
        ``None`` when ``num_epochs`` is 0.

    Raises:
        ValueError: If ``train_frac`` leaves a negative number of samples for
            the training or test split.
        NotImplementedError: For an unsupported model or optimizer.
    """
    st_all_time = time.time()

    device = torch.device('cuda')

    print('==> Preparing data...')

    dataset = PickleDataset(f"cached_data/{data_flag}.pkl")
    n_classes = len(dataset.classes)
    print(f"Dataset has {len(dataset)} images and {n_classes} classes")
    train_len = int(len(dataset) * train_frac)
    val_len = int(len(dataset) * 0.2)
    test_len = len(dataset) - train_len - val_len
    # The three lengths always sum to len(dataset), so a negative length is
    # not caught by that sum and has to be rejected here explicitly.
    if train_len < 0 or test_len < 0:
        raise ValueError(
            f"train_frac={train_frac} does not fit next to the fixed 20% "
            f"validation split (train={train_len}, val={val_len}, test={test_len})"
        )
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(41),  # fixed seed: same split on every call
    )

    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   pin_memory=True)
    val_loader = data.DataLoader(dataset=val_dataset,
                                 batch_size=512,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=True)
    test_loader = data.DataLoader(dataset=test_dataset,
                                  batch_size=512,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  pin_memory=True)
    print('==> Building and training model...')

    if model_flag == 'resnet18':
        model = resnet18(weights=ResNet18_Weights.DEFAULT)
    elif model_flag == 'resnet50':
        model = resnet50(weights=ResNet50_Weights.DEFAULT)
    else:
        raise NotImplementedError
    # Replace the classification head so it matches this dataset's classes.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    model = model.to(device)
    model.train()

    val_evaluator = Evaluator(data_flag, 'val')
    test_evaluator = Evaluator(data_flag, 'test')

    criterion = nn.CrossEntropyLoss()

    if num_epochs == 0:
        return

    # Parameter groups carrying layer-wise learning rates and hyperparameters.
    trainable_params = get_lr_weights_with_optimizer(model, optimizer_params, optimizer_type)

    if optimizer_type == 'adamw':
        optimizer = torch.optim.AdamW(trainable_params)
    elif optimizer_type == 'rmsprop':
        optimizer = torch.optim.RMSprop(trainable_params)
    else:
        raise NotImplementedError(f"Optimizer {optimizer_type} not supported")

    for _ in range(num_epochs):
        train(model, train_loader, criterion, optimizer, device)

    val_metrics = test(model, val_evaluator, val_loader, criterion, device)
    result = val_metrics
    if test_flag:
        test_metrics = test(model, test_evaluator, test_loader, criterion, device)
        result = (val_metrics, test_metrics)

    # Report the wall-clock time for both the search and the test path.
    end_all_time = time.time()
    print(f"Time taken: {end_all_time - st_all_time}")
    return result


def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer that is stepped once per batch.
        device: Device the inputs and targets are moved to.

    Returns:
        The average of the per-batch loss values as a float, or ``nan`` when
        ``train_loader`` yields no batches.
    """
    batch_losses = []

    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))

        # Targets are cast to integer class indices before the loss call.
        targets = targets.long().to(device)
        loss = criterion(outputs, targets)
        batch_losses.append(loss.item())

        loss.backward()
        optimizer.step()

    # An empty loader has no meaningful mean; avoid dividing by zero.
    if not batch_losses:
        return float('nan')
    return sum(batch_losses) / len(batch_losses)


def test(model, evaluator: Evaluator, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and return ``[auc, acc]``.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        evaluator: Object whose ``evaluate(targets, scores)`` returns the
            ``(auc, acc)`` pair.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Not used; kept so existing call sites stay valid.
        device: Device the inputs and targets are moved to.

    Returns:
        ``[auc, acc]`` as computed by ``evaluator.evaluate``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    # Collect per-batch results and concatenate once at the end instead of
    # re-copying a growing tensor on every batch.
    score_batches = []
    target_batches = []

    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))
            score_batches.append(torch.softmax(outputs, dim=1))

            # The evaluator receives the targets as a float column vector.
            targets = targets.long().to(device)
            target_batches.append(targets.float().reshape(-1, 1))

    if not score_batches:
        raise ValueError("data_loader yielded no batches to evaluate")

    y_score = torch.cat(score_batches, 0).cpu().numpy()
    targets_all = torch.cat(target_batches, 0).cpu().numpy()
    auc, acc = evaluator.evaluate(targets_all, y_score)

    return [auc, acc]


def _optimizer_group_options(optimizer_type, optimizer_hyperparams):
    """Translate the trailing hyperparameter values into optimizer options."""
    if optimizer_type == 'adamw':
        # optimizer_hyperparams = [beta1, beta2, weight_decay, eps]
        beta1, beta2, weight_decay, eps = (float(v) for v in optimizer_hyperparams[:4])
        return {'betas': (beta1, beta2), 'weight_decay': weight_decay, 'eps': eps}
    if optimizer_type == 'rmsprop':
        # optimizer_hyperparams = [alpha, eps, weight_decay, momentum]
        alpha, eps, weight_decay, momentum = (float(v) for v in optimizer_hyperparams[:4])
        return {'alpha': alpha, 'eps': eps, 'weight_decay': weight_decay, 'momentum': momentum}
    # Any other optimizer type gets learning rates only.
    return {}


def get_lr_weights_with_optimizer(model, optimizer_params, optimizer_type, downsample_lr=1e-4):
    """Build optimizer parameter groups with layer-wise learning rates.

    One group is created for every trainable parameter whose name contains
    ``'conv'``, together with the parameters of the same name with ``'conv'``
    replaced by ``'bn'`` (weight and bias). A further group holds
    ``fc.weight``/``fc.bias``. Names that do not exist in the model are
    skipped. Parameters whose name contains ``'downsample'`` share one extra
    group that uses ``downsample_lr``.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        optimizer_params: Sequence laid out as
            ``[layer_lrs..., optimizer_hyperparams...]`` with one learning
            rate per group (conv groups first, ``fc`` last) followed by four
            values: ``[beta1, beta2, weight_decay, eps]`` for ``'adamw'`` or
            ``[alpha, eps, weight_decay, momentum]`` for ``'rmsprop'``.
        optimizer_type: ``'adamw'`` or ``'rmsprop'``; any other value yields
            groups that only carry ``lr``.
        downsample_lr: Learning rate of the downsample group.

    Returns:
        A list of parameter-group dicts that can be passed to a torch
        optimizer constructor.

    Raises:
        ValueError: If ``optimizer_params`` holds fewer values than the
            layer groups and optimizer hyperparameters require.
    """
    model_params_dict = dict(model.named_parameters())

    # Each conv weight is grouped with its same-named batch-norm weight/bias.
    layer_groups = []
    for name, param in model_params_dict.items():
        if param.requires_grad and 'conv' in name:
            bn_weight = name.replace('conv', 'bn')
            bn_bias = bn_weight.replace('weight', 'bias')
            layer_groups.append([name, bn_weight, bn_bias])
    layer_groups.append(['fc.weight', 'fc.bias'])

    num_layers = len(layer_groups)
    num_hyperparams = 4 if optimizer_type in ('adamw', 'rmsprop') else 0
    if len(optimizer_params) < num_layers + num_hyperparams:
        raise ValueError(
            f"optimizer_params has {len(optimizer_params)} values but "
            f"{num_layers} layer learning rates and {num_hyperparams} "
            f"{optimizer_type} hyperparameters are required"
        )

    layer_lrs = optimizer_params[:num_layers]
    shared_options = _optimizer_group_options(optimizer_type, optimizer_params[num_layers:])

    optim_dict_list = []
    for lr, param_names in zip(layer_lrs, layer_groups):
        group_params = [model_params_dict[n] for n in param_names if n in model_params_dict]
        # float() turns NumPy scalars into plain Python floats.
        optim_dict_list.append({'params': group_params, 'lr': float(lr), **shared_options})

    # Downsample parameters are not part of the search; they use a fixed rate.
    down_sample_params = [p for name, p in model_params_dict.items() if 'downsample' in name]
    if down_sample_params:
        optim_dict_list.append({'params': down_sample_params, 'lr': downsample_lr, **shared_options})

    return optim_dict_list


class OptimizerProblem(Problem):
    """Multi-objective problem over layer-wise learning rates and optimizer hyperparameters.

    The decision vector is ``[layer_lrs..., optimizer_hyperparams...]``. The
    objectives are the negated metrics returned by the training task, so that
    minimising them maximises the metrics.
    """

    # Number of learning-rate variables per supported backbone.
    _LAYER_COUNTS = {"resnet18": 18, "resnet50": 50}

    # (lower, upper) bounds of the optimizer hyperparameters.
    _HYPERPARAM_BOUNDS = {
        # beta1, beta2, weight_decay, eps
        'adamw': ([0.8, 0.9, 1e-6, 1e-10], [0.999, 0.9999, 1e-2, 1e-6]),
        # alpha, eps, weight_decay, momentum
        'rmsprop': ([0.9, 1e-10, 1e-6, 0.0], [0.999, 1e-6, 1e-2, 0.99]),
    }

    def __init__(self, args, optimizer_type, n_var=None, n_obj=2):
        """Set up variable bounds for the chosen backbone and optimizer.

        Args:
            args: Keyword arguments for ``train_moo_with_optimizer``; must
                contain ``"model_flag"``.
            optimizer_type: ``'adamw'`` or ``'rmsprop'``.
            n_var: Ignored; the variable count is derived from the backbone
                and optimizer. Kept for interface compatibility.
            n_obj: Number of objectives.

        Raises:
            NotImplementedError: For an unsupported optimizer or backbone.
        """
        if optimizer_type not in self._HYPERPARAM_BOUNDS:
            raise NotImplementedError(f"Optimizer {optimizer_type} not supported")
        model_flag = args["model_flag"]
        if model_flag not in self._LAYER_COUNTS:
            # Fail here instead of silently assuming 50 layers.
            raise NotImplementedError(f"Model {model_flag} not supported")

        n_layers = self._LAYER_COUNTS[model_flag]
        lower, upper = self._HYPERPARAM_BOUNDS[optimizer_type]
        n_var = n_layers + len(lower)

        xl = np.zeros(n_var)
        xu = np.zeros(n_var)

        # Learning rate bounds for layers
        xl[:n_layers] = 1e-7
        xu[:n_layers] = 1e-3

        # Optimizer-specific hyperparameter bounds
        xl[n_layers:] = lower
        xu[n_layers:] = upper

        super().__init__(n_var=n_var, n_obj=n_obj, xl=xl, xu=xu)
        # Copy so later changes to the caller's dict do not alter the search.
        self.args = dict(args)
        self.optimizer_type = optimizer_type

    def _evaluate(self, x, out, *args, **kwargs):
        """Train one model per row of ``x`` in parallel and store the objectives.

        Extra positional and keyword arguments are accepted and ignored,
        because pymoo forwards additional arguments (such as the running
        algorithm) to ``_evaluate``.
        """
        task_args = dict(self.args)
        # optimizer_type is passed explicitly below; drop the duplicate.
        task_args.pop('optimizer_type', None)
        # The search needs the validation metrics only; with test_flag set the
        # task returns a (val, test) pair and the objective array would have
        # the wrong shape.
        task_args['test_flag'] = False
        futures = [
            train_moo_with_optimizer.remote(row, optimizer_type=self.optimizer_type, **task_args)
            for row in x
        ]
        results = ray.get(futures)
        out["F"] = -np.array(results)


def run_optimizer_moo(args, optimizer_type, pop_size=8, n_gen=12):
    """Run the NSGA-II search for one optimizer type, then test and log the result.

    Args:
        args: Keyword arguments for ``train_moo_with_optimizer``. The dict is
            not modified.
        optimizer_type: ``'adamw'`` or ``'rmsprop'``.
        pop_size: NSGA-II population size.
        n_gen: Number of generations to run.

    Raises:
        RuntimeError: If the search returns no solution.
    """
    problem = OptimizerProblem(args, optimizer_type, n_obj=2)
    algorithm = NSGA2(pop_size=pop_size, eliminate_duplicates=True)

    res = minimize(problem,
                   algorithm,
                   ('n_gen', n_gen),
                   verbose=False)
    if res.X is None:
        raise RuntimeError("The optimization returned no solution")

    # Test phase: retrain every returned solution with the test split enabled.
    # A new dict is used so the caller's args are left untouched.
    test_args = {**args, 'test_flag': True}
    # optimizer_type is passed explicitly; drop it to avoid a duplicate argument.
    task_args = {key: value for key, value in test_args.items() if key != 'optimizer_type'}

    solutions = np.atleast_2d(res.X)
    futures = [
        train_moo_with_optimizer.remote(solution, optimizer_type=optimizer_type, **task_args)
        for solution in solutions
    ]
    results = ray.get(futures)

    # Each result is (val_metrics, test_metrics); keep the test metrics.
    pareto_test_metrics = np.array([result[1] for result in results])
    wandb_log(test_args, res, pareto_test_metrics, optimizer_type)


def wandb_log(args, res, pareto_test_metrics, optimizer_type):
    """Start a wandb run and record summary statistics of the test metrics.

    The run is tagged with the dataset flag and the optimizer type. Column 0
    of ``pareto_test_metrics`` is logged under the ``auc`` keys and column 1
    under the ``acc`` keys (max, mean and standard deviation), together with
    ``res.X`` as a string and ``res.F``.
    """
    run_tags = [args["data_flag"], f"optimizer_{optimizer_type}"]
    wandb.init(project="moo_class", tags=run_tags)
    wandb.config.update(args)
    wandb.config.update({"optimizer_type": optimizer_type})

    auc_column = pareto_test_metrics[:, 0]
    acc_column = pareto_test_metrics[:, 1]
    payload = {
        "test_auc": np.max(auc_column),
        "test_acc": np.max(acc_column),
        "mean_auc": np.mean(auc_column),
        "mean_acc": np.mean(acc_column),
        "std_auc": np.std(auc_column),
        "std_acc": np.std(acc_column),
        "X": str(res.X),
        "F": res.F,
    }
    wandb.log(payload)


if __name__ == '__main__':
    # Command-line entry point: parse the options, then run the search for
    # the requested optimizer.
    parser = argparse.ArgumentParser(
        description='MOO optimization with different optimizers')

    parser.add_argument('--data_flag',
                        default='SRSMAS',
                        type=str)
    parser.add_argument('--num_epochs',
                        default=50,
                        help='num of epochs of training',
                        type=int)
    parser.add_argument('--batch_size',
                        default=64,
                        type=int)
    parser.add_argument('--num_workers',
                        default=0,
                        type=int)
    parser.add_argument('--model_flag',
                        default='resnet18',
                        help='choose backbone from resnet18, resnet50',
                        type=str)
    parser.add_argument('--optimizer_type',
                        default='adamw',
                        help='choose optimizer: adamw, rmsprop',
                        type=str)
    parser.add_argument('--test_flag',
                        action="store_true")
    parser.add_argument('--train_frac',
                        default=0.4,
                        type=float)

    args = vars(parser.parse_args())

    # Force tune_option to lrmoo for this implementation
    args['tune_option'] = 'lrmoo'

    if args["optimizer_type"] in ["adamw", "rmsprop"]:
        # Start Ray only once the arguments are known to be usable, so that
        # --help or an unsupported optimizer does not start it.
        ray.init()
        run_optimizer_moo(args, args["optimizer_type"])
    else:
        print(f"Optimizer {args['optimizer_type']} not supported. Choose from: adamw, rmsprop")