from contextlib import nullcontext
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR

from model_gpt_rk1base import ModelGPTRK1B


class Trainer:
    """Mixed-precision training / evaluation driver with gradient accumulation.

    Typical use: construct with an ``args`` namespace, call :meth:`init` once
    with the model (this creates the optimizer, LR scheduler, autocast context
    and gradient scaler), then call the instance (or :meth:`run`) once per
    batch and :meth:`test` for evaluation batches.

    Attributes read from ``args``: ``device``, ``device_type``, ``lr``,
    ``lr_min``, ``warmup_steps``, ``epochs`` and ``grad_acc_steps``.
    """

    def __init__(self, args):
        """Store the configuration namespace; no other state is created here."""
        self.args = args

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`run`."""
        return self.run(*args, **kwargs)

    def epoch(self):
        """Reset the batch counter that drives gradient accumulation."""
        self.iters = 0

    def init(self, model):
        """Create the optimizer, scheduler, autocast context and grad scaler.

        Must be called before :meth:`run` or :meth:`test`, which rely on the
        attributes set here.

        Args:
            model: Module whose ``parameters()`` are handed to AdamW.
        """
        self.iters = 0

        # Autocast and loss scaling are only used off-CPU. Tying the scaler to
        # the same flag avoids building an enabled CUDA scaler for CPU runs.
        use_amp = self.args.device != 'cpu'
        if use_amp:
            self.ctx = torch.amp.autocast(
                device_type=self.args.device_type,
                dtype=torch.float16)
        else:
            self.ctx = nullcontext()

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=self.args.lr, betas=(0.9, 0.98), eps=1.E-9)

        # Linear warm-up for `warmup_steps` scheduler steps, then cosine decay
        # down to `lr_min`.
        # NOTE(review): T_max is `epochs - warmup_steps`, so both are presumably
        # counted in the same unit (scheduler steps) -- confirm with the caller.
        scheduler_warmup = LinearLR(
            self.optimizer,
            total_iters=self.args.warmup_steps)
        scheduler_decay = CosineAnnealingLR(
            self.optimizer,
            T_max=self.args.epochs - self.args.warmup_steps,
            eta_min=self.args.lr_min)
        self.scheduler = SequentialLR(
            self.optimizer,
            milestones=[self.args.warmup_steps],
            schedulers=[scheduler_warmup, scheduler_decay])

        self.scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def _to_device(self, model, data_x, data_y):
        """Move a batch onto the device of the model's first parameter.

        Also records that device in ``self.device``. ``data_y`` may be a single
        tensor or a list of tensors; the same structure is returned.
        """
        self.device = next(model.parameters()).device
        data_x = data_x.to(self.device)
        if isinstance(data_y, list):
            data_y = [part.to(self.device) for part in data_y]
        else:
            data_y = data_y.to(self.device)
        return data_x, data_y

    def _head_loss(self, head, features, target, dimension):
        """Return one head's mean negative log-likelihood as a 0-dim tensor.

        ``features`` is 3-D (unpacked as B, T, C). The head output is flattened
        to ``(B*T, -1)``, log-softmaxed over the last axis, and the entries
        selected by ``target`` are averaged. The result is divided by
        ``dimension`` and then by ``args.grad_acc_steps``.
        """
        batch, steps, _ = features.shape
        log_probs = nn.functional.log_softmax(
            head(features).reshape(batch * steps, -1), dim=1)
        picked = torch.gather(
            log_probs, dim=1, index=target.reshape(-1, 1)).squeeze(1)
        nll = -1. * torch.mean(picked) / dimension
        return nll / self.args.grad_acc_steps

    def run(self, model, data_x, data_y):
        """Run forward/backward on one batch; step every ``grad_acc_steps`` calls.

        Args:
            model: Model to train. ``ModelGPTRK1B`` instances take the
                head-wise path; any other model must return ``(_, loss)``.
            data_x: Input tensor.
            data_y: Target tensor or list of target tensors.

        Returns:
            float: This batch's loss, already divided by ``grad_acc_steps``.
        """
        self.iters += 1

        data_x, data_y = self._to_device(model, data_x, data_y)

        model.train()

        # Float accumulator so the return type is uniform even with no heads.
        loss = 0.0

        with self.ctx:
            if isinstance(model, ModelGPTRK1B):
                x, model_heads, target, dimension = model(data_x, data_y)

                # Cut the graph at the trunk output: each head backpropagates
                # only into `d`, whose .grad accumulates across heads, and the
                # trunk is then traversed once with that summed gradient.
                d = x.detach()
                d.requires_grad = True
                for i, head in enumerate(model_heads):
                    ls = self._head_loss(head, d, target[i], dimension)
                    self.scaler.scale(ls).backward()
                    loss += ls.item()
                # d.grad stays None when there are no heads; nothing to
                # propagate into the trunk in that case.
                if d.grad is not None:
                    x.backward(gradient=d.grad)
            else:
                _, loss = model(data_x, data_y)
                loss = loss / self.args.grad_acc_steps
                self.scaler.scale(loss).backward()

        if self.iters % self.args.grad_acc_steps == 0:
            # Gradients still carry the loss-scale factor at this point; they
            # must be unscaled first or the clip threshold is applied to the
            # scaled values instead of the true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)

            # NOTE(review): self.scheduler is built in init() but is not
            # stepped by this class (the step call was disabled here).

        if isinstance(loss, float):
            return loss
        return loss.item()

    def test(self, model, data_x, data_y):
        """Evaluate one batch without gradients and return its loss.

        The model is put in eval mode for the forward pass and switched back
        to train mode before returning.

        Args:
            model: Model to evaluate. ``ModelGPTRK1B`` instances take the
                head-wise path; any other model is called with
                ``with_w_norm=False`` and must return ``(_, loss)``.
            data_x: Input tensor.
            data_y: Target tensor or list of target tensors.

        Returns:
            float: The batch loss.
        """
        data_x, data_y = self._to_device(model, data_x, data_y)

        model.eval()

        with torch.inference_mode(), self.ctx:
            if isinstance(model, ModelGPTRK1B):
                loss = 0.0
                x, model_heads, target, dimension = model(data_x, data_y)
                # NOTE(review): the head-wise loss is divided by
                # grad_acc_steps here, while the generic branch below is
                # not -- kept as-is; confirm which one is intended.
                for i, head in enumerate(model_heads):
                    loss += self._head_loss(
                        head, x, target[i], dimension).item()
            else:
                _, loss = model(data_x, data_y, with_w_norm=False)

        model.train()

        if isinstance(loss, float):
            return loss
        return loss.item()
