import copy

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from methods.base import Base

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.byol_transform import (
    BYOLTransform,
    BYOLView1Transform,
    BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule
from lightly.models.utils import get_weight_decay_parameters, update_momentum
from lightly.transforms import BYOLTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule


class CoOpt(Base):
    """Train a backbone so a linear projection of its features matches given targets.

    Each batch element provides two augmented views and a target vector ``z``.
    Only the first view is currently used. Backbone features are mapped through
    ``self.transfer`` and pulled toward ``z`` with a negative cosine similarity
    loss. The online classifier loss returned by
    ``Base.online_classifier_training_step`` is added to the objective.
    """

    def __init__(self, backbone, args):
        super().__init__(backbone, args)
        # Linear head mapping backbone features (args.feature_dim) into the
        # target space (args.coopt_dim).
        self.transfer = nn.Linear(self.args.feature_dim, self.args.coopt_dim)
        self.criterion = NegativeCosineSimilarity()

    def forward(self, x):
        """Embed ``x`` with the backbone and project it.

        Returns:
            ``(z, h)`` where ``h`` is the backbone output flattened from dim 1
            onward and ``z = self.transfer(h)``.
        """
        h = self.backbone(x).flatten(start_dim=1)
        z = self.transfer(h)
        return z, h

    def training_step(self, batch, batch_idx):
        """Compute the CoOpt loss plus the online-classifier loss for one batch.

        ``batch`` is unpacked as ``((x0, x1), z), y``: two views, a target
        vector ``z`` and labels ``y``. The second view ``x1`` is not used.

        Returns:
            The sum of the cosine regression loss and the classifier loss.
        """
        ((x0, _x1), z), y = batch
        z_hat, h_hat = self.forward(x0)

        loss = self.criterion(z_hat, z)

        # NOTE(review): h_hat is passed without .detach(); unless Base's
        # online_classifier_training_step detaches it, the classifier's
        # cross-entropy gradient also updates the backbone — confirm.
        loss_ce = self.online_classifier_training_step(h_hat, y)

        # Log the CoOpt objective (not the classifier cross-entropy) so the
        # representation loss being optimized is actually tracked.
        self.log("train_loss", loss)
        return loss + loss_ce

    def configure_optimizers(self):
        """Optimize all parameters with AdamW using ``args.learning_rate`` and
        ``args.weight_decay``."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
    


class Transform:
    """Stochastic augmentation that yields two correlated views of one example.

    The same random pipeline is applied twice to an input, producing views
    x̃_i and x̃_j that are treated as a positive pair.
    """

    def __init__(self, input_size=32, color_jitter_strength=1.0):
        """
        Args:
            input_size: Output side length of ``RandomResizedCrop``.
            color_jitter_strength: Multiplier ``s`` for ``ColorJitter``.
                Brightness, contrast and saturation are ``0.8 * s`` and hue
                is ``0.2 * s``; torchvision requires hue <= 0.5. The default
                1.0 reproduces the previous fixed setting.
        """
        s = color_jitter_strength
        color_jitter = torchvision.transforms.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        self.transform = torchvision.transforms.Compose(
            [
                # ToPILImage accepts a tensor or ndarray; inputs are presumably
                # one of those (not already PIL images) — confirm with dataset.
                torchvision.transforms.ToPILImage(),
                torchvision.transforms.RandomResizedCrop(
                    size=input_size, scale=(0.08, 1.0)
                ),
                torchvision.transforms.RandomHorizontalFlip(),  # p=0.5 default
                torchvision.transforms.RandomApply([color_jitter], p=0.8),
                torchvision.transforms.RandomGrayscale(p=0.2),
                torchvision.transforms.ToTensor(),
                # Standard ImageNet RGB channel statistics.
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __call__(self, x):
        """Return a tuple of two independently augmented views of ``x``."""
        return self.transform(x), self.transform(x)
    

def TransformCoOpt(dataset, input_size):
    """Create the two-view augmentation used for CoOpt training.

    Args:
        dataset: Not used here; presumably kept so callers can use a uniform
            factory signature across methods — confirm.
        input_size: Crop size forwarded to ``Transform``.

    Returns:
        A ``Transform`` whose call yields two augmented views.
    """
    return Transform(input_size=input_size)