import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import lightly
from lightly.models.modules.heads import SimSiamProjectionHead
from torchmetrics import Accuracy
from src.util.knn_utils import knn_predict
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.models.utils import get_weight_decay_parameters

class SSL_Model(pl.LightningModule):
    """Self-supervised model: ResNet backbone + projection head (+ optional online classifier).

    The training loss is produced by ``args.criterion(args)`` and is evaluated on
    the embeddings of two views, gathered across all processes. When
    ``args.online_classifier`` is set, a linear classifier is trained on detached
    embeddings alongside the SSL objective.

    Args:
        args: configuration namespace. Fields read by this class: ``backbone``,
            ``projection``, ``projector_dim``, ``criterion``, ``ssl_epochs``,
            ``online_classifier``, ``small_init``, ``small_init_factor``,
            ``log_eigenvalues``, ``devices``, ``num_nodes``, ``ssl_optimizer``,
            ``ssl_learning_rate``, ``ssl_weight_decay``, ``ssl_lr_scheduler``
            and ``ssl_warm_restart_epochs``.
        num_classes: number of target classes for the online classifier.

    Raises:
        ValueError: if ``args.backbone`` or ``args.projection`` is not supported.
    """

    # Upper bound on how many leading eigenvalues are logged per step.
    _NUM_LOGGED_EIGENVALUES = 20

    def __init__(self, args, num_classes):
        super().__init__()
        self.args_ = args
        self.save_hyperparameters("args")
        self.num_classes = num_classes
        self.max_epochs = args.ssl_epochs
        self.projector_dim = args.projector_dim
        self.criterion = args.criterion(args)

        # BACKBONE
        if args.backbone == 'resnet18':
            input_dim = 512
            backbone = lightly.models.ResNetGenerator('resnet-18', num_splits=0)
        elif args.backbone == 'resnet50':
            input_dim = 2048
            backbone = lightly.models.ResNetGenerator('resnet-50', num_splits=0)
        else:
            raise ValueError("backbone type %s is not supported!" % args.backbone)

        # Drop the generator's last child and pool spatially instead.
        self.backbone = nn.Sequential(
            *list(backbone.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # PROJECTOR
        # ``feature_dim`` tracks the width of what ``forward`` actually returns,
        # so downstream consumers (the online classifier) are sized correctly.
        if args.projection == 'mlp':
            self.projection_head = SimSiamProjectionHead(input_dim=input_dim, output_dim=self.projector_dim)
            feature_dim = self.projector_dim
        elif args.projection == 'identity':
            self.projection_head = nn.Identity()
            # Identity passes the backbone features through unchanged.
            feature_dim = input_dim
        elif args.projection == 'linear':
            self.projection_head = nn.Linear(input_dim, self.projector_dim)
            feature_dim = self.projector_dim
        else:
            raise ValueError("projector type %s is not supported!" % args.projection)

        # ONLINE CLASSIFIER
        if args.online_classifier:
            self.online_classifier = OnlineLinearClassifier(num_classes=num_classes, feature_dim=feature_dim)

        # SMALL INIT, IF NEEDED
        # Runs after all sub-modules exist, so every "weight" parameter is scaled.
        if args.small_init:
            self.scale_weights(args.small_init_factor)

    def setup(self, stage):
        """Record the number of optimizer steps per epoch when fitting starts."""
        if stage == "fit":
            # calculate number of train steps per epoch. weird hack due to https://github.com/Lightning-AI/lightning/issues/10430
            num_batches = len(self.trainer._data_connector._train_dataloader_source.dataloader())
            self.train_steps_per_epoch = num_batches // (self.args_.devices * self.args_.num_nodes)

    def forward(self, x):
        """Return the projected embedding of ``x`` (backbone -> flatten -> projection head)."""
        x = self.backbone(x).flatten(start_dim=1)
        x = self.projection_head(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute the SSL loss (plus the online-classifier loss, if enabled).

        ``batch`` is unpacked as ``((x0, x1), targets, _)``: two views of the
        same samples, their labels, and a third element that is ignored.
        """
        (x0, x1), targets, _ = batch
        z_a = self.forward(x0)
        z_b = self.forward(x1)

        # collect all batches together. needed for complete pairwise interaction
        all_z_a = self.all_gather(z_a, sync_grads=True)
        all_z_a = all_z_a.view(-1, z_a.shape[1])
        all_z_b = self.all_gather(z_b, sync_grads=True)
        all_z_b = all_z_b.view(-1, z_a.shape[1])

        loss = self.criterion(all_z_a, all_z_b)
        self.log('train_loss', loss, batch_size=len(targets))

        # log eigenvalues
        if self.args_.log_eigenvalues:
            self._log_embedding_spectrum(all_z_a)

        # online linear classifier
        total_loss = loss
        if self.args_.online_classifier:
            features = torch.cat([z_a, z_b], dim=0)
            # Features are detached: the classifier must not influence the encoder.
            cls_loss, cls_log = self.online_classifier.training_step(
                (features.detach(), targets.repeat(2)), batch_idx
            )
            self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
            total_loss = loss + cls_loss

        return total_loss

    def _log_embedding_spectrum(self, embeddings):
        """Log the leading eigenvalues of the centered embedding Gram matrix.

        ``embeddings`` is a 2-D ``(N, D)`` tensor. The matrix decomposed here is
        the ``N x N`` Gram matrix of the centered rows divided by ``N``; its
        non-zero eigenvalues coincide with those of the ``D x D`` covariance.
        """
        # Pure logging: keep it out of the autograd graph.
        with torch.no_grad():
            centered = embeddings - embeddings.mean(dim=0)
            gram = (centered @ centered.T) / centered.shape[0]
            eigenvalues = torch.linalg.eigvalsh(gram)
            sorted_eigs = torch.sort(eigenvalues, descending=True).values

        # The matrix has only N eigenvalues; do not index past them on small batches.
        num_logged = min(self._NUM_LOGGED_EIGENVALUES, sorted_eigs.shape[0])
        logging_dict = {
            f'cov_eig_{i}': sorted_eigs[i] for i in range(num_logged)
        }
        logging_dict['step'] = self.global_step

        self.log('cov_eigs', logging_dict, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Evaluate the online classifier; log a zero placeholder when it is disabled."""
        if self.args_.online_classifier:
            images, targets, _ = batch
            features = self.forward(images).flatten(start_dim=1)
            cls_loss, cls_log = self.online_classifier.validation_step(
                (features.detach(), targets), batch_idx
            )
            self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        else:
            # Keep the 'val_online_cls_top1' key present even without a classifier.
            # Created on the module's device so the synced log does not mix devices.
            cls_loss = torch.tensor(0.0, device=self.device)
            self.log('val_online_cls_top1', cls_loss, sync_dist=True)

        return cls_loss

    def configure_optimizers(self):
        """Return the optimizer and LR scheduler in Lightning's list format."""
        optim = self.setup_optimizer()
        scheduler = self.setup_scheduler(optim)
        return [optim], [scheduler]

    # make optimizer
    def setup_optimizer(self):
        """Build the optimizer selected by ``args.ssl_optimizer`` ('SGD' or 'Adam').

        Raises:
            ValueError: if the optimizer type is not supported.
        """
        learning_rate = self.args_.ssl_learning_rate
        weight_decay = self.args_.ssl_weight_decay

        if self.args_.ssl_optimizer == 'SGD':
            if self.args_.online_classifier:
                params, params_no_weight_decay = get_weight_decay_parameters(
                    [self.backbone, self.projection_head]
                )
                param_groups = [
                    {"name": "regular_params", "params": params},
                    {"name": "nodecay_params", "params": params_no_weight_decay, "weight_decay": 0},
                    {"name": "online_classifier", "params": self.online_classifier.parameters(), "weight_decay": 0},
                ]
            else:
                # NOTE(review): unlike the branch above, weight decay is applied
                # to every parameter here. Confirm whether that is intended.
                param_groups = self.parameters()

            return torch.optim.SGD(
                param_groups,
                lr=learning_rate,
                momentum=0.9,
                weight_decay=weight_decay,
            )

        if self.args_.ssl_optimizer == 'Adam':
            return torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)

        raise ValueError("optimizer type %s is not supported!" % self.args_.ssl_optimizer)

    def setup_scheduler(self, optim):
        """Build the LR scheduler selected by ``args.ssl_lr_scheduler``.

        The two cosine variants are stepped per optimizer step and therefore
        rely on ``self.train_steps_per_epoch`` (set in ``setup``); 'steplr' is
        returned as a bare scheduler.

        Raises:
            ValueError: if the scheduler type is not supported.
        """
        # make scheduler
        scheduler_type = self.args_.ssl_lr_scheduler

        if scheduler_type == 'cosineannealing':
            # end cosine annealing cycle on last epoch
            tmax = int(self.args_.ssl_epochs * self.train_steps_per_epoch)
            return {
                'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=tmax, eta_min=1e-5),
                'interval': 'step'
            }

        if scheduler_type == 'cosinewarmlr':
            # how many iterations before cosine annealing restart
            t_0 = int(self.args_.ssl_warm_restart_epochs * self.train_steps_per_epoch)
            return {
                'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optim, T_0=t_0, eta_min=1e-5),
                'interval': 'step'
            }

        if scheduler_type == 'steplr':
            return torch.optim.lr_scheduler.StepLR(optim, step_size=50, gamma=0.8)

        raise ValueError("lr scheduler type %s is not supported!" % scheduler_type)

    # adapted from "On the Stepwise Nature of Self-Supervised Learning" implementation
    @torch.no_grad()
    def scale_weights(self, scale: float = 1):
        """Multiply every parameter whose name contains "weight" by ``scale``, in place.

        A ``scale`` of 1 leaves all parameters untouched.
        """
        print(f"scaling weights by factor of {scale}")
        if scale == 1:
            return

        for name, param in self.named_parameters():
            if "weight" in name:
                print(f"scaling {name} by factor of {scale}")
                param.copy_(param * scale)


class Linear_Probe(pl.LightningModule):
    """Linear classifier trained on top of a frozen backbone's features.

    ``forward`` applies only the linear layer, so batches are expected to
    already contain feature vectors (512-d for 'resnet18', 2048-d for
    'resnet50'). The backbone is stored with gradients disabled.

    Args:
        args: configuration namespace. Fields read by this class: ``backbone``,
            ``probe_optimizer`` and ``probe_learning_rate``.
        backbone: feature extractor; all of its parameters are frozen here.
        num_classes: number of output classes.

    Raises:
        ValueError: if ``args.backbone`` is not supported.
    """

    def __init__(self, args, backbone, num_classes):
        super().__init__()
        self.args_ = args
        self.save_hyperparameters("args")

        if args.backbone == 'resnet18':
            self.layer = nn.Linear(512, num_classes)
        elif args.backbone == 'resnet50':
            self.layer = nn.Linear(2048, num_classes)
        else:
            raise ValueError("backbone type %s is not supported!" % args.backbone)

        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)

        # Top-5 accuracy is only tracked with at least 5 classes. The attribute
        # is always defined so the eval steps can test it without raising
        # AttributeError.
        if num_classes >= 5:
            self.accuracy_metric_top5 = Accuracy(task='multiclass', num_classes=num_classes, top_k=5)
        else:
            self.accuracy_metric_top5 = None

        self.max_accuracy = 0

    def forward(self, x):
        """Return class logits for the given features (input is detached)."""
        #x = self.backbone(x).flatten(start_dim=1)
        x = self.layer(x.detach())
        return x

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y, _)`` batch."""
        x, y, _ = batch

        preds = self.forward(x)
        loss = self.criterion(preds, y)

        self.log('linprobe_train_loss', loss)
        return {"loss": loss}

    def _evaluate(self, batch, prefix, prog_bar=False):
        """Log top-1 (and, if enabled, top-5) accuracy under ``{prefix}_accuracy*``.

        Returns the batch top-1 accuracy as a Python float.
        """
        x, y, _ = batch
        preds = self.forward(x)

        # Weight the epoch average by the real batch length so that a short
        # final batch is not over-counted.
        batch_size = len(y)

        acc = self.accuracy_metric(preds, y)
        self.log(f'{prefix}_accuracy', acc, on_epoch=True, prog_bar=prog_bar, sync_dist=True, batch_size=batch_size)

        if self.accuracy_metric_top5 is not None:
            acc_top5 = self.accuracy_metric_top5(preds, y)
            self.log(f'{prefix}_accuracy_top5', acc_top5, on_epoch=True, sync_dist=True, batch_size=batch_size)

        return acc.item()

    def validation_step(self, batch, batch_idx):
        """Log validation accuracy metrics and return the batch top-1 accuracy."""
        return self._evaluate(batch, 'validation', prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test accuracy metrics and return the batch top-1 accuracy."""
        return self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Build the probe optimizer ('SGD' or 'Adam') with a StepLR schedule.

        Raises:
            ValueError: if ``args.probe_optimizer`` is not supported.
        """
        if self.args_.probe_optimizer == 'SGD':
            optim = torch.optim.SGD(self.parameters(), lr=self.args_.probe_learning_rate, momentum=0.9, weight_decay=0)
        elif self.args_.probe_optimizer == 'Adam':
            optim = torch.optim.Adam(self.parameters(), lr=self.args_.probe_learning_rate, weight_decay=0)
        else:
            # Report the option that was actually checked.
            raise ValueError("optimizer type %s is not supported!" % self.args_.probe_optimizer)

        scheduler = {
            'scheduler': torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.75),
            'name': f'probe-{self.args_.probe_optimizer}-lr'
        }

        return [optim], [scheduler]
