import torch
import math
import matplotlib.pyplot as plt

def configure_optimizers(model, learning_rate=0.001, weight_decay=0.01, pe_available=True, betas=(0.9, 0.999)):
    """
    Build an AdamW optimizer that applies weight decay to only part of the model.

    Adapted from:
    https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136

    Parameters are split into two groups:

    * no decay: every parameter whose name ends in ``bias`` (this includes
      ``in_proj_bias``), the ``weight`` of ``LayerNorm`` / ``BatchNorm1d``
      modules, and, when ``pe_available`` is true and the model has it, the
      parameter named ``pos_enc.pe``;
    * decay: every other parameter whose name ends in ``weight`` (this
      includes ``in_proj_weight``).

    Args:
        model: ``torch.nn.Module`` whose parameters are to be optimized.
        learning_rate: learning rate passed to ``AdamW``.
        weight_decay: weight decay used for the "decay" group.
        pe_available: if true, ``pos_enc.pe`` is placed in the no-decay group
            when the model exposes a parameter with that name.
        betas: ``betas`` passed to ``AdamW``.

    Returns:
        A ``torch.optim.AdamW`` with two parameter groups: decayed parameters
        first, non-decayed parameters second.

    Raises:
        AssertionError: if a parameter ends up in both groups, or in neither
            (i.e. its name ends in neither ``weight`` nor ``bias`` and it is
            not the special-cased ``pos_enc.pe``).
    """
    # modules whose weights must never be decayed (normalization layers)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.BatchNorm1d)

    param_dict = dict(model.named_parameters())

    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    for mn, m in model.named_modules():
        # recurse=False: each parameter is seen exactly once, through the
        # module that owns it, so the isinstance check below is meaningful
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f"{mn}.{pn}" if mn else pn  # full param name

            if fpn not in param_dict:
                # names that model.named_parameters() does not report (e.g. a
                # second alias of a tied parameter) cannot be looked up later
                continue

            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight'):
                if isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)
                else:
                    # all remaining weights will be weight decayed
                    decay.add(fpn)

    # special case the position embedding parameter as not decayed; skipped
    # when the model has no such parameter so the default flag cannot cause a
    # KeyError when the optimizer groups are built
    if pe_available and 'pos_enc.pe' in param_dict:
        no_decay.add('pos_enc.pe')

    # validate that we considered every parameter
    inter_params = decay & no_decay
    union_params = decay | no_decay
    missing_params = param_dict.keys() - union_params
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(missing_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                % (str(missing_params), )

    # create the pytorch optimizer object (sorted for a deterministic order)
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]

    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
    return optimizer

class WarmupCosineSchedule(object):
    """Learning-rate schedule with a linear warmup followed by cosine decay.

    For the first ``warmup_steps`` calls to :meth:`step` the learning rate
    rises linearly from ``initial_lr`` towards ``peak_lr``. Afterwards it
    follows half a cosine wave from ``peak_lr`` down to ``min_lr``, reaching
    ``min_lr`` at step ``total_training_steps`` and staying there for any
    further steps.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts whose
            ``"lr"`` entry is overwritten on every step.
        total_training_steps: step index at which the cosine decay ends.
        warmup_steps: number of linear warmup steps (0 disables warmup).
        initial_lr: learning rate at the first step of the warmup.
        min_lr: learning rate at the end of the cosine decay.
        peak_lr: learning rate at the end of warmup / start of the decay.
    """

    def __init__(
        self,
        optimizer,
        total_training_steps,
        warmup_steps,
        initial_lr,
        min_lr,
        peak_lr,
    ):
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.initial_lr = initial_lr
        # every learning rate returned by step(), in order
        self.lr_curve = []

        # -1 so that the first call to step() evaluates the schedule at step 0
        self._step = -1

        self.total_training_steps = total_training_steps
        # number of steps spanned by the cosine decay
        self.T_max = self.total_training_steps - self.warmup_steps
        # per-step increase during warmup; no warmup means no increment
        # (and avoids dividing by zero)
        if warmup_steps > 0:
            self.lr_increment = (peak_lr - self.initial_lr) / warmup_steps
        else:
            self.lr_increment = 0.0

    def step(self):
        """Advance one step, write the new learning rate into every param group and return it."""
        self._step += 1

        if self._step < self.warmup_steps:
            lr = self.initial_lr + self._step * self.lr_increment
        else:
            if self.T_max > 0:
                # clamp to 1 so the rate stays at min_lr after the decay ends
                # instead of following the cosine back up
                progress = min(1.0, (self._step - self.warmup_steps) / self.T_max)
            else:
                # no room for a decay phase: treat it as already finished
                progress = 1.0
            lr = self.min_lr + (self.peak_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        self.lr_curve.append(lr)

        return lr

class loss_tracker():
    """Accumulate losses per key over an interval and keep a history of interval means.

    Typical use, as far as this class shows: call :meth:`track` for each loss
    value and :meth:`step` once per counted step, then :meth:`update` to turn
    the accumulated sums into means, append them to the history and reset the
    accumulators.

    Args:
        objective: mapping whose keys name the tracked losses (only the keys
            are used).
        ft: when truthy, the extra ``"combined"`` entry is not tracked.
    """

    def __init__(self, objective, ft = False):
        # running sums for the current interval
        self.running_losses = {key: 0 for key in objective.keys()}
        self.running_losses["combined"] = 0
        self.ft = ft
        # number of step() calls since the last update()
        self.steps = 0

        # history of per-interval means, one list per key
        self.losses = {key: [] for key in objective.keys()}
        self.losses["combined"] = []

        if self.ft:
            self.running_losses.pop("combined")
            self.losses.pop("combined")

    def track(self, key, loss):
        """Add ``loss`` to the running sum of ``key``."""
        self.running_losses[key] += loss

    def step(self):
        """Count one step towards the current interval."""
        self.steps += 1

    def update(self):
        """Close the current interval: record mean losses, reset sums, print them.

        Does nothing when no step was counted since the last update, because
        there is nothing to average (and dividing by zero steps would fail).
        """
        if self.steps == 0:
            return

        for key, loss in self.running_losses.items():
            self.losses[key].append(loss/self.steps)
            self.running_losses[key] = 0
        self.steps = 0
        self.print_losses()

    def get_losses(self):
        """Return the full history dict (key -> list of interval means)."""
        return self.losses

    def get_loss(self, key):
        """Return the list of interval means recorded for ``key``."""
        return self.losses[key]

    def print_losses(self):
        """Print the most recent mean of every key on one line."""
        for key, loss in self.losses.items():
            if not loss:
                # nothing recorded yet for this key
                continue
            print(f"{key} loss: {loss[-1]:.5f}", end=' | ')
        print("")

    def plot_losses(self):
        """Plot the history of every key with matplotlib and show the figure."""
        for key, loss in self.losses.items():
            plt.plot(loss, label=key)
        plt.legend()
        plt.show()
