from copy import deepcopy
from rich import print as pp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.feature_extraction import create_feature_extractor
import pytorch_lightning as pl
from timm.optim import create_optimizer_v2, optimizer_kwargs

from models.network import get_network
from utils.loss import SoftTargetCrossEntropy, LabelSmoothingCrossEntropy, SPRegularization
from utils.get_scheduler import get_scheduler
from utils.mixup import Mixup, NoMixup
from utils.transmix import Mixup_transmix
from utils.metrics import AccuracyPL, AP_PL
from utils.bss import BatchSpectralShrinkage
from torchmetrics.classification import MultilabelAveragePrecision
from torchmetrics.utilities.data import dim_zero_cat


def freeze(model, mode='attn'):
    """Freeze ``model``'s parameters for partial fine-tuning.

    Args:
        model: network whose parameters are (un)frozen (an ``nn.Module``).
        mode: ``'all'`` deep-copies ``model``, freezes every parameter of the
            copy and returns it, leaving ``model`` itself untouched.
            ``'attn'`` / ``'mlp'`` modify ``model`` in place. Only parameters
            whose name contains ``'.attn.'`` / ``'.mlp.'`` stay trainable,
            plus the classifier (``head``, or ``fc`` when there is no usable
            ``head``) weight and bias and ``pos_embed`` when present.
            ``patch_embed`` parameters are explicitly forced to frozen.

    Returns:
        The frozen copy for ``'all'``; ``model`` itself for ``'attn'``/``'mlp'``.

    Raises:
        AssertionError: if ``mode`` is not ``'attn'``, ``'mlp'`` or ``'all'``.
        AttributeError: in ``'attn'``/``'mlp'`` mode, if neither ``head`` nor
            ``fc`` provides a ``weight`` and ``bias``.
    """
    assert mode in ('attn', 'mlp', 'all')
    if mode == 'all':
        pp("[purple] freeze all [/purple]")
        frozen = deepcopy(model)
        for param in frozen.parameters():
            param.requires_grad = False
        return frozen

    pp(f"[purple] only {mode} will be trained [/purple]")
    needle = f'.{mode}.'
    for name, param in model.named_parameters():
        param.requires_grad = needle in name

    # Below, only a missing attribute is an expected (best-effort) condition.
    # Catching AttributeError instead of a bare ``except`` keeps
    # KeyboardInterrupt and genuine errors from being silently swallowed.
    try:
        model.head.weight.requires_grad = True
        model.head.bias.requires_grad = True
    except AttributeError:
        model.fc.weight.requires_grad = True
        model.fc.bias.requires_grad = True
    try:
        model.pos_embed.requires_grad = True
    except AttributeError:
        print('no position encoding')
    try:
        # Runs after the name filter, so patch_embed stays frozen even if one
        # of its parameter names happens to contain the mode substring.
        for param in model.patch_embed.parameters():
            param.requires_grad = False
    except AttributeError:
        print('no patch embed')
    return model


class Model(pl.LightningModule):
    """Lightning wrapper around the classification network built by ``get_network``.

    Optional behaviours are switched on by flags on ``args``:

    * ``l2sp`` / ``feature_kd`` / ``msa_kd`` / ``guide`` build a frozen
      reference ("teacher") network in ``self.pretrained_model``.
    * ``feature_kd`` / ``msa_kd`` wrap student and teacher with
      ``create_feature_extractor`` and add a per-layer feature-matching loss.
    * ``guide`` adds ``self.model.guide(attn, teacher_attn)`` to the loss.
    * ``bss`` adds a ``BatchSpectralShrinkage`` penalty on the student tokens.
    * ``attn_only`` / ``mlp_only`` restrict which student parameters train.

    Every auxiliary loss is scaled by ``args.glambda``.
    """

    # Number of blocks tapped for feature/MSA distillation.
    # NOTE(review): hard-coded to 12 — confirm it matches the backbone depth.
    KD_NUM_LAYERS = 12

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = get_network(args)
        need_l2 = args.l2sp
        need_kd = args.feature_kd or args.msa_kd
        need_guide = args.guide
        make_pretrained = need_guide or need_kd or need_l2

        if make_pretrained:
            if self.args.teacher_checkpoint is not None:
                teacher_args = deepcopy(args)
                teacher_args.initial_checkpoint = teacher_args.teacher_checkpoint
                pp(f'[purple] build teacher network from teacher checkpoint : {teacher_args.initial_checkpoint} [/purple]')
                self.pretrained_model = get_network(teacher_args)
                self.pretrained_model = freeze(self.pretrained_model, mode='all')
            else:
                pp('teacher and student networks will be same arch and same weight')
                # freeze(mode='all') returns a frozen deep copy, so the
                # student itself stays trainable.
                self.pretrained_model = freeze(self.model, mode='all')
            # The teacher must stay in eval mode; ``train`` below re-applies
            # this whenever the module is switched back to training.
            self.pretrained_model.eval()

        if need_kd:
            # Tap one residual ``add`` node per block plus the classifier head.
            # feature_kd taps ``add_1`` and msa_kd taps ``add`` (feature_kd wins
            # if both are set).
            # NOTE(review): assumes blocks where ``add`` follows attention and
            # ``add_1`` follows the MLP — confirm for this backbone.
            add_node = 'add_1' if args.feature_kd else 'add'
            get_node_lists = {f'blocks.{i}.{add_node}': f'layer{i}'
                              for i in range(self.KD_NUM_LAYERS)}
            get_node_lists['head'] = 'head'
            self.model = create_feature_extractor(self.model, get_node_lists)
            self.pretrained_model = create_feature_extractor(
                self.pretrained_model, get_node_lists)

        if args.attn_only:
            self.model = freeze(self.model, mode='attn')
        elif args.mlp_only:
            self.model = freeze(self.model, mode='mlp')

        self.val_loss = nn.CrossEntropyLoss()
        if args.mixup_active:
            self.loss = SoftTargetCrossEntropy()
        elif args.dataset_name == "voc":
            # VOC is multi-label: sigmoid/BCE for both train and validation.
            self.loss = nn.BCEWithLogitsLoss()
            self.val_loss = nn.BCEWithLogitsLoss()
        elif args.smoothing:
            self.loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
        else:
            self.loss = nn.CrossEntropyLoss()

        if args.l2sp:
            self.reg_loss = SPRegularization(self.pretrained_model, self.model)
        pp(f"[green] loss : {self.loss}")

        if args.mixup_active:
            pp("[green] Mixup or Cutmix will help you [/green]")
            mixup_args = dict(
                mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
                cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob, switch_prob=args.mixup_switch_prob,
                mode=args.mixup_mode, label_smoothing=args.smoothing,
                num_classes=args.num_classes)
            self.mixup_fn = Mixup(**mixup_args)
            if args.transmix:
                self.mixup_fn = Mixup_transmix(**mixup_args)
        else:
            self.mixup_fn = NoMixup()
        if args.bss:
            self.bss_module = BatchSpectralShrinkage()

        self.build_metrics()

    def train(self, mode=True):
        """Set train/eval mode on the module while keeping the teacher in eval mode.

        The trainer switches the whole LightningModule back to train mode
        during fitting (e.g. after validation). Without this override that
        would undo ``self.pretrained_model.eval()`` and re-enable dropout,
        stochastic depth and BatchNorm statistic updates in the frozen teacher.
        """
        super().train(mode)
        teacher = getattr(self, 'pretrained_model', None)
        if teacher is not None:
            teacher.eval()
        return self

    def build_metrics(self):
        """Create the train/valid metrics.

        Uses top-1/top-5 accuracy in general, and macro average precision
        for the multi-label VOC dataset.
        """
        self.train_acc = AccuracyPL(topk=(1, 5))
        self.valid_acc = AccuracyPL(topk=(1, 5))
        if self.args.dataset_name == "voc":
            self.train_acc = MultilabelAveragePrecision(
                num_labels=self.args.num_classes,
                average='macro')
            self.valid_acc = AP_PL(
                num_labels=self.args.num_classes,
                average='macro')

    def forward(self, x):
        """Run the student network.

        Returns the feature-extractor dict (``layer{i}`` taps and ``head``) in
        KD mode, otherwise the ``(logits, attn)`` pair produced by the network.
        """
        if self.args.feature_kd or self.args.msa_kd:
            return self.model(x)
        out, attn = self.model(x)
        return out, attn

    def pretrained_forward(self, x):
        """Run the frozen teacher network; same return convention as ``forward``."""
        if self.args.feature_kd or self.args.msa_kd:
            return self.pretrained_model(x)
        out, attn = self.pretrained_model(x)
        return out, attn

    def _kd_loss(self, s_out, t_out):
        """Return the feature-distillation loss between student and teacher taps.

        For every ``layer*`` entry the summed squared difference is divided by
        ``numel * KD_NUM_LAYERS``, which averages over elements and layers.
        The ``head`` entry is ignored.
        """
        kd_loss = torch.tensor(0.0, device=self.device)
        for (s_name, s_feat), (t_name, t_feat) in zip(s_out.items(), t_out.items()):
            if 'layer' in s_name and 'layer' in t_name:
                denom = s_feat.numel() * self.KD_NUM_LAYERS
                kd_loss = kd_loss + (s_feat - t_feat).pow(2).sum() / denom
        return kd_loss

    def train_(self, batch, batch_idx):
        """Compute the training loss for one batch (without the guide term).

        Returns:
            ``(loss, ce_loss, 0., logits, attn, (x, y), reg_loss, bss_loss)``.
            ``loss`` is the full objective. ``ce_loss`` is the classification
            term alone. ``x`` is the (possibly mixed) input and ``y`` the
            original targets. ``reg_loss`` holds the L2-SP and/or KD penalty
            (unscaled).
        """
        x, y = batch
        x, _y = self.mixup_fn(x, y)
        use_kd = self.args.feature_kd or self.args.msa_kd
        part_tokens = None
        if use_kd:
            s_out = self.model(x)
            # The teacher is frozen; no graph is needed for its activations.
            with torch.no_grad():
                t_out = self.pretrained_model(x)
            logits = s_out['head']
            ce_loss = self.loss(logits, _y)
            attn = 0.0
        else:
            x_ = self.model.forward_features(x)
            part_tokens = x_[0]
            logits, attn = self.model.forward_head(x_)
            if self.args.transmix and isinstance(_y, tuple):
                # TransMix re-weights the mixed labels by last-layer CLS attention.
                cls_attn = attn[:, -1, :, 1:].clone().softmax(dim=-1).mean(dim=1).detach()
                _y = self.mixup_fn.transmix_label(_y, cls_attn, x.shape)
            ce_loss = self.loss(logits, _y)

        # Accumulate out of place: ``loss += ...`` on a tensor would mutate
        # ``ce_loss`` itself and make the logged CE equal the total loss.
        loss = ce_loss
        reg_loss = torch.tensor(0.0, device=self.device)
        bss_loss = torch.tensor(0.0, device=self.device)

        if self.args.l2sp:
            reg_loss = self.reg_loss()
            loss = loss + self.args.glambda * reg_loss
        if use_kd:
            # Add only the KD term to ``loss`` so the L2-SP term is not counted twice.
            kd_loss = self._kd_loss(s_out, t_out)
            reg_loss = reg_loss + kd_loss
            loss = loss + self.args.glambda * kd_loss
        if self.args.bss:
            if part_tokens is None:
                raise ValueError(
                    "bss cannot be combined with feature_kd/msa_kd: the "
                    "feature-extractor student does not expose part tokens")
            bss_loss = self.bss_module(part_tokens)
            loss = loss + self.args.glambda * bss_loss
        return loss, ce_loss, torch.tensor(0.), logits, attn, (x, y), reg_loss, bss_loss

    def train_with_guide(self, batch, batch_idx):
        """
        Custom Train Module for Attention Guide.

        Adds ``glambda * self.model.guide(attn, teacher_attn)`` to the loss
        from ``train_``. The teacher attention is computed on the same
        (possibly mixed) input without gradients.
        """
        loss, ce_loss, _, logits, attn, (x, y), reg_loss, bss_loss = self.train_(batch, batch_idx)
        with torch.no_grad():
            _, attn_guide = self.pretrained_forward(x)
        if self.args.guide:
            guide_loss = self.model.guide(attn, attn_guide)
        else:
            raise Exception("Not Implemented")
        loss = loss + (guide_loss * self.args.glambda)
        return loss, ce_loss, guide_loss, logits, attn, (x, y), reg_loss, bss_loss

    def training_step(self, batch, batch_idx=None):
        """Compute the loss for one batch, update the train metric and return loss terms."""
        train = self.train_with_guide if self.args.guide else self.train_
        loss, ce_loss, guide_loss, logits, _, (_, y), reg_loss, bss_loss = train(batch, batch_idx)

        if self.args.dataset_name == "voc":
            preds = torch.sigmoid(logits)
            ap = self.train_acc(preds, y)
        else:
            preds = F.softmax(logits, dim=1)
            self.train_acc.update(preds, y)
            ap = 0.0
        return {
            'loss': loss,
            'train_ce_loss': ce_loss,
            'train_guide_loss': guide_loss,
            'train_ap': ap,
            'train_reg_loss': reg_loss,
            'train_bss_loss' : bss_loss,
            }

    def training_epoch_end(self, outputs):
        """Log epoch-averaged training losses/metrics and reset the train metric."""
        # Call ``set_epoch`` only on samplers that provide it (as
        # DistributedSampler does); others are skipped instead of raising.
        sampler = getattr(self.trainer.train_dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(self.current_epoch + 1)

        def epoch_mean(key):
            return torch.stack([x[key] for x in outputs]).mean()

        train_avg_loss = epoch_mean("loss")
        train_ce_loss = epoch_mean("train_ce_loss")
        train_guide_loss = epoch_mean("train_guide_loss")
        train_reg_loss = epoch_mean("train_reg_loss")
        train_bss_loss = epoch_mean("train_bss_loss")
        if self.args.dataset_name == "voc":
            # Mean of the per-batch AP values returned by the metric's forward.
            train_top1 = epoch_mean("train_ap")
            train_top5 = 0.0
            train_length = 0.0
        else:
            train_top1, train_top5, train_length = self.train_acc.compute()
        self.logger.log_metrics(
            {"train loss": train_avg_loss,
             "train ce loss": train_ce_loss,
             "train guide_loss": train_guide_loss,
             'train_reg_loss': train_reg_loss,
             'train_bss_loss': train_bss_loss,
             "train top1": train_top1,
             "train top5": train_top5,
             "train total": train_length}, step=self.current_epoch)
        self.train_acc.reset()

    def evaluate(self, batch, stage=None):
        """Compute the evaluation loss.

        Updates the validation metric when ``stage == "val"``.
        """
        x, y = batch
        if self.args.feature_kd or self.args.msa_kd:
            out = self(x)
            logits = out['head']
        else:
            logits, _ = self(x)
        loss = self.val_loss(logits, y)
        preds = F.softmax(logits, dim=1) if self.args.dataset_name != 'voc' else torch.sigmoid(logits)
        if stage == "val":
            self.valid_acc.update(preds, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx=None):
        return self.evaluate(batch, "val")

    def validation_epoch_end(self, outputs):
        """Log the epoch-averaged validation loss and metric, then reset the metric."""
        valid_avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        if self.args.dataset_name == "voc":
            valid_top1, valid_top5, valid_length = self.valid_acc.compute(), 0.0, self.valid_acc.total
        else:
            valid_top1, valid_top5, valid_length = self.valid_acc.compute()
        self.logger.log_metrics(
            {"valid loss": valid_avg_loss,
             "valid top1": valid_top1,
             "valid top5": valid_top5,
             "valid total": valid_length}, step=self.current_epoch)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx=None):
        return self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Build the timm optimizer from ``args`` and an epoch-stepped LR scheduler."""
        optimizer = create_optimizer_v2(
            self.parameters(),
            **optimizer_kwargs(cfg=self.args))
        lr_scheduler = get_scheduler(
            self.args.sched,
            optimizer,
            self.args.end_epoch,
            self.args.decay_t,
            args=self.args,
            min_lr=self.args.min_lr,
        )
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": "epoch"
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
