import torch
import torch.nn as nn
from torch.nn.functional import softmax
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import optuna
import copy
import os
import json
from utils import set_global_seed
from cifar10_training import get_optimizer
import models
import search_spaces
import argparse


class DistApproximation(nn.Module):
    '''
    Learnable categorical distribution over `d` classes, parameterised by logits.

    The logits `theta` are initialised from a standard normal drawn with a
    freshly created torch.Generator seeded with `seed`, so initialisation is
    reproducible and does not consume the global RNG stream.
    '''

    def __init__(self, d, seed):
        super().__init__()
        theta = torch.randn(d, generator=torch.Generator().manual_seed(seed))
        self.theta = nn.Parameter(theta, requires_grad=True)

    def forward(self):
        '''Return the probability vector softmax(theta) of length d.'''
        # Explicit dim: calling softmax without `dim` is deprecated and warns on every call.
        return softmax(self.theta, dim=-1)


def KL(p, q):
    '''
    KL-divergence KL(p || q) between two categorical distributions.

    Uses the convention 0 * log 0 = 0, so classes with p_k == 0 contribute
    nothing instead of turning the result into NaN.

    :param p: Probability vector of the reference distribution.
    :param q: Probability vector of the approximating distribution (same shape as p).
    :return: Scalar tensor sum_k p_k * (log p_k - log q_k).
    '''
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum()

def CrossEntropy(p, q):
    '''
    Cross-entropy H(p, q) = -sum_k p_k * log q_k between categorical distributions.

    Uses the convention 0 * log 0 = 0, so classes with p_k == 0 contribute
    nothing even where q_k == 0 (the naive product would give NaN there).

    :param p: Target probability vector.
    :param q: Predicted probability vector (same shape as p); gradients flow through q.
    :return: Scalar tensor.
    '''
    return -torch.xlogy(p, q).sum()

def running_avg(losses):
    '''Return the prefix means of `losses`: entry i is mean(losses[0..i]).'''
    counts = np.arange(1, len(losses) + 1)
    prefix_sums = np.cumsum(losses)
    return prefix_sums / counts

def get_data(d=1000, alpha=1.0, power=1.0):
    '''
    Generate the target categorical distribution.

    The distribution mixes a normalised power law with the uniform distribution:
        p_k = alpha * (k ** -power / Z) + (1 - alpha) / d,   k = 1..d,
    where Z = sum_j j ** -power.

    :param d: Number of categories.
    :param alpha: Weight of the power-law component
        - =1.0 gives the pure power law (heavy tail)
        - =0.0 gives the uniform distribution
    :param power: Exponent of the power law.
    :return: 1-D tensor of length d in the default floating dtype.
    '''
    # Build ranks in floating point: an integer arange raised to an integer
    # power can overflow int64 for large d.
    ranks = torch.arange(1, d + 1, dtype=torch.get_default_dtype())
    p = 1 / ranks ** power
    p /= p.sum()
    p = alpha * p + (1 - alpha) / d
    return p


def train(num_epoches, model, p, optimizer, device='cpu', trial=None, clipping=None, scheduler=None):
    '''
    Minimise CrossEntropy(p, model()) with one full-batch optimizer step per "epoch".

    :param num_epoches: Number of optimizer steps to take.
    :param model: Module whose forward() takes no inputs and returns the predicted
        probability vector q. Assumed to already be on `device` (it is not moved here).
    :param p: Target probability vector; moved to `device`.
    :param optimizer: Optimizer over model.parameters().
    :param device: Device that `p` is moved to.
    :param trial: Optional Optuna trial. If given, each step's loss is reported
        and the run is aborted when the trial should be pruned.
    :param clipping: If not None, gradients are clipped to this infinity-norm.
    :param scheduler: Optional LR scheduler, stepped once after every optimizer step.
    :return: List of per-step loss values. Each value is computed before that
        step's parameter update.
    :raises optuna.exceptions.TrialPruned: When `trial` says the run should be pruned.
    '''
    losses = []
    p = p.to(device)
    for epoch in range(num_epoches):
        optimizer.zero_grad()
        q = model()
        loss = CrossEntropy(p, q)
        loss.backward()
        if clipping is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        loss_value = loss.item()  # one device->host sync per step
        if trial is not None:
            trial.report(loss_value, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
        losses.append(loss_value)
        # Infinity norm (max |g_i|) of the gradient after clipping, across all parameters.
        grad_inf_norm = max(
            (
                prm.grad.detach().abs().max().item()
                for prm in model.parameters()
                if prm.grad is not None and prm.grad.numel() > 0
            ),
            default=0.0,
        )
        print(f"Epoch [{epoch+1}/{num_epoches}], Loss: {loss_value:.4f}, grad_norm: {grad_inf_norm}")
    return losses


def suggest_params(trial, search_space, prefix=None):
    '''
    Sample one value per tunable hyperparameter from an Optuna trial.

    Entries of `search_space` whose value is a dict are tunable specs with keys:
    'type' ('float' -> trial.suggest_float, any other value -> trial.suggest_int),
    'min', 'max', and an optional 'log' flag (defaults to False, matching Optuna's
    own default). Entries that are not dicts are fixed settings and are skipped.

    :param trial: Optuna trial used for sampling.
    :param search_space: Mapping from hyperparameter name to a spec dict or a fixed value.
    :param prefix: If given, the name registered with Optuna is f'{prefix}_{name}'.
    :return: Dict mapping each tunable (unprefixed) name to its sampled value.
    '''
    params = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            continue  # fixed setting, not tuned
        param_name = name if prefix is None else f'{prefix}_{name}'
        suggest = trial.suggest_float if spec['type'] == 'float' else trial.suggest_int
        params[name] = suggest(param_name, spec['min'], spec['max'], log=spec.get('log', False))
    return params

def tune(n_trials, search_space, optimizer_name, ModelCls, device, num_epoches, n_startup_trials, path, seed=42, use_augmentations=False):
    '''
    Tune `optimizer_name` with Optuna (TPE sampler), then re-run the best configuration.

    Every run (each trial and the final re-run) does the following:
      - re-seeds via set_global_seed(seed);
      - builds a fresh ModelCls(search_space['d'], seed) on `device`;
      - builds the target with get_data(d, alpha, power) taken from `search_space`;
      - trains for `num_epoches` steps.
    Each run is scored by the mean of log(loss) over all steps; lower is better.

    The study is stored in SQLite at `<path>/study/<optimizer_name>.db` and is
    resumed if it already exists. Whenever a trial beats the best completed
    trial so far, its params, score ('loss') and number ('trial') are written
    with save_result.

    :param n_trials: Number of Optuna trials to run in this call.
    :param search_space: Dict passed to get_optimizer. It must also contain
        'd', 'alpha' and 'power' for the target distribution.
    :param optimizer_name: Optimizer key for get_optimizer. Also used as the
        study name and the output file name.
    :param ModelCls: Model class, constructed as ModelCls(d, seed).
    :param device: Torch device for the model and target.
    :param num_epoches: Optimizer steps per run (also passed to get_optimizer as n_iters).
    :param n_startup_trials: Number of startup trials for the TPE sampler.
    :param path: Output directory.
    :param seed: Seed for the sampler, the global RNG and model initialisation.
    :param use_augmentations: Not used by this function; kept for interface compatibility.
    :return: Dict of floats with the best trial's params ('momentum' and
        'weight_decay' forced to 0) plus 'val_score' from the final re-run.
    '''
    def run_once(**optimizer_kwargs):
        # One seeded, from-scratch training run on the same target as every other
        # run (including `power`); returns the mean log-loss over all steps.
        set_global_seed(seed)
        model = ModelCls(search_space['d'], seed).to(device)
        p = get_data(d=search_space['d'], alpha=search_space['alpha'], power=search_space['power'])
        optimizer, (clipping, scheduler) = get_optimizer(
            optimizer_name, model, search_space, n_iters=num_epoches, **optimizer_kwargs
        )
        losses = train(num_epoches=num_epoches,
                       model=model,
                       p=p,
                       optimizer=optimizer,
                       device=device,
                       clipping=clipping,
                       scheduler=scheduler)
        return running_avg(np.log(losses))[-1]

    def objective(trial):
        # NOTE(review): `trial` is not forwarded to train(), so the MedianPruner
        # below never receives intermediate values and never prunes — confirm intended.
        score = run_once(trial=trial)
        try:
            best_so_far = trial.study.best_value
        except ValueError:
            # No completed trial yet (first trial, or all earlier trials failed).
            best_so_far = None
        if best_so_far is None or best_so_far > score:
            result = trial.params | {'loss': score, 'trial': trial.number}
            save_result(result, path, optimizer_name)
        return score

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    os.makedirs(os.path.join(path, 'study'), exist_ok=True)
    storage_path = f"sqlite:///{os.path.join(path, 'study', optimizer_name + '.db')}"

    study = optuna.create_study(
        study_name=f"{optimizer_name}",
        storage=storage_path,
        direction="minimize",
        pruner=pruner,
        sampler=sampler,
        load_if_exists=True
    )
    study.optimize(objective, n_trials=n_trials)
    best_trial = study.best_trial

    # Work on a copy so the trial object's own params dict is not mutated.
    opt_params = dict(best_trial.params)
    opt_params['momentum'] = 0
    opt_params['weight_decay'] = 0
    score = run_once(trial=None, optimizer_params=opt_params)
    result = opt_params | {'val_score': score}
    return {key: float(value) for key, value in result.items()}


def get_arguments():
    '''
    Parse the command-line options for one tuning run.

    :return: argparse.Namespace with the fields optimizer, user, dist_dim,
        dist_power, heavy_tail_param, n_trials, n_startup_trials, max_epochs,
        device and use_augmentations.
    '''
    optimizer_choices = [
        'Muon',
        'SGD-approx', 'Signum-approx',
        'SignumDL', 'SignumDLNesterov',
        'Signum', 'SignumLinearLR',
        'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
        'AdamW', 'Adam', 'AdamEps',
        'AdamWClip', 'AdamEpsScheduling', 'AdamBetaScheduling', 'AdamPaLM2',
        'SGD', 'SGDLinearLR', 'SGDCosineAnnealingLR', 'SGDClip', 'SGDLinearLR+Clip',
        'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
        'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
        'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
        'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto', 'SoftSignumSGD_not_decoupled_wd-const',
        'Signum+SGD', 'Signum+SGD_not_decoupled_wd', 'Signum+SGD_with_LinearLR',
    ]
    default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=optimizer_choices,
                        help='Optimizer to use for training.')
    parser.add_argument('--user', type=str, required=True, choices=['user1', 'user2', 'user3'],
                        help='Name of the user folder for saving the results.')
    # Target-distribution parameters (see get_data).
    parser.add_argument('--dist_dim', type=int, default=1000,
                        help='The number of dimensions in categorical distribution')
    parser.add_argument('--dist_power', type=float, default=1.0)
    parser.add_argument('--heavy_tail_param', type=float, default=1.0,
                        help='1.0 is heavy tail distribution, 0.0 is uniform distribution')
    # Optuna budget.
    parser.add_argument('--n_trials', type=int, default=50,
                        help='Number of Optuna trials to run.')
    parser.add_argument('--n_startup_trials', type=int, default=20,
                        help='Number of Optuna startup trials to run.')
    parser.add_argument('--max_epochs', type=int, default=50,
                        help='Maximum number of epochs for training each trial.')
    parser.add_argument('--device', type=str, default=default_device,
                        help='Device to use for training (e.g., "cpu", "cuda:0").')
    parser.add_argument('--use_augmentations', action="store_true",
                        help="Use augmentations")
    return parser.parse_args()


def save_result(
    result: dict,
    path: str,
    optimizer_name: str
):
    '''
    Write `result` as JSON to `<path>/<optimizer_name>.json`.

    The directory `path` is created if it is missing, and any existing file is
    overwritten.
    '''
    target_file = f'{path}/{optimizer_name}.json'
    os.makedirs(path, exist_ok=True)
    with open(target_file, 'w') as handle:
        json.dump(result, handle)


if __name__ == '__main__':
    args = get_arguments()
    path = f'tuning/{args.user}/heavy_tail_data_alpha_{args.heavy_tail_param}_pow_{args.dist_power}-final-final'

    # Look up this optimizer's search space in search_spaces.search_spaces_map.
    # The dict is modified in place to pin the target-distribution parameters
    # and the settings held fixed in this experiment.
    search_space = search_spaces.search_spaces_map[args.optimizer]
    search_space.update({
        'd': args.dist_dim,
        'alpha': args.heavy_tail_param,
        'power': args.dist_power,
        'momentum': 0,
        'weight_decay': 0,
        'only_sign_iters': 0.8,
        'warmup_iters': 0.15,
    })

    results = tune(
        n_trials=args.n_trials,
        n_startup_trials=args.n_startup_trials,
        num_epoches=args.max_epochs,
        search_space=search_space,
        optimizer_name=args.optimizer,
        ModelCls=DistApproximation,
        device=args.device,
        use_augmentations=args.use_augmentations,
        path=path,
    )

    # Overwrites the per-trial best-so-far JSON that tune() writes to the same
    # file, replacing it with the final re-run result.
    save_result(results, path, args.optimizer)