import torch, torchvision
import torch.nn as nn

from lightly.models.modules import SimCLRProjectionHead
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.models.modules.heads import VICRegProjectionHead

from oucl.agents.losses import load_loss
from oucl.agents.optimization import load_optimizer, load_scheduler

from abc import  abstractmethod

import numpy as np

class BaseEncoder(nn.Module):
    """Shared base class for the encoder models defined in this module.

    Provides parameter freezing and optimizer / scheduler construction.
    ``forward``, ``embed`` and ``training_step`` are marked abstract and are
    meant to be supplied by subclasses.

    NOTE: this class does not derive from ``abc.ABC`` itself, so the abstract
    markers act as documentation of the intended interface.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Run the model on ``x``; implemented by subclasses."""

    @abstractmethod
    def embed(self, x):
        """Return an embedding of ``x``; implemented by subclasses."""

    @abstractmethod
    def training_step(self):
        """Perform one training step; implemented by subclasses."""

    def freeze_layers(self, num_frozen):
        """Disable gradients for the first ``num_frozen`` parameter tensors.

        Parameters are visited in the order produced by ``self.parameters()``.
        The requested count is stored on the instance as ``num_frozen``.

        Args:
            num_frozen: Number of leading parameter tensors to freeze.
        """
        for position, tensor in enumerate(self.parameters()):
            if position >= num_frozen:
                # Everything from here on stays trainable.
                break
            tensor.requires_grad = False

        self.num_frozen = num_frozen

    def configure_optimizers(self, config):
        """Create ``self.optimizer`` and ``self.lr_scheduler`` from ``config``.

        Only parameters that still require gradients are handed to the
        optimizer, so tensors frozen beforehand are left out.

        Args:
            config: Configuration object. Reads ``config.agent.optimizer`` and
                ``config.agent.lr_scheduler``; the full object is also passed
                on to ``load_optimizer`` and ``load_scheduler``.
        """
        # Lazy iterable, consumed by load_optimizer.
        trainable = (param for param in self.parameters() if param.requires_grad)

        self.optimizer = load_optimizer(trainable, config.agent.optimizer, config)
        self.lr_scheduler = load_scheduler(
            self.optimizer,
            config.agent.lr_scheduler,
            config,
        )


class SupervisedEncoder(BaseEncoder):
    """Backbone followed by a single linear classification layer.

    Args:
        config: Configuration object. Reads ``config.device`` and
            ``config.agent.num_classes``; it is also passed to
            ``load_backbone`` and ``configure_optimizers``.
    """

    def __init__(self, config):
        super().__init__()

        self.device = config.device

        self.backbone = load_backbone(config)
        self.linear = nn.Linear(self.backbone.out_dim, config.agent.num_classes)

        # Called after the layers exist so their parameters reach the optimizer.
        self.configure_optimizers(config)

        self.step = 0

    def embed(self, x):
        """Return the backbone output for ``x`` with gradients disabled.

        The input is moved to ``self.device`` (the same device ``forward``
        uses) instead of a hard-coded ``'cuda'``, so this also works on
        CPU-only hosts and with explicitly selected devices.
        """
        with torch.no_grad():
            return self.backbone(x.to(self.device))

    def _forward(self, x):
        """Compute flattened backbone features and class predictions.

        Returns:
            Dict with ``'z'`` (backbone output flattened from dim 1) and
            ``'y_pred'`` (output of the linear layer applied to ``z``).
        """
        z = self.backbone(x).flatten(start_dim=1)
        y = self.linear(z)
        return {'z': z, 'y_pred': y}

    def forward(self, batch):
        """Move ``batch`` to ``self.device`` and return the ``_forward`` dict."""
        return self._forward(batch.to(self.device))
    
class SimCLREncoder(BaseEncoder):
    """Backbone with a SimCLR projection head.

    Args:
        config: Configuration object. Reads ``config.device``,
            ``config.agent.hidden_dim`` and ``config.agent.proj_dim``; it is
            also passed to ``load_backbone`` and ``configure_optimizers``.
    """

    def __init__(self, config):
        super().__init__()

        self.device = config.device

        self.backbone = load_backbone(config)
        self.projection_head = SimCLRProjectionHead(self.backbone.out_dim,
                                                    config.agent.hidden_dim,
                                                    config.agent.proj_dim)

        # Called after the layers exist so their parameters reach the optimizer.
        self.configure_optimizers(config)

        self.step = 0

    def embed(self, x):
        """Return the backbone output for ``x`` with gradients disabled.

        The input is moved to ``self.device`` (the same device ``forward``
        uses) instead of a hard-coded ``'cuda'``, so this also works on
        CPU-only hosts and with explicitly selected devices.
        """
        with torch.no_grad():
            return self.backbone(x.to(self.device))

    def _forward(self, x):
        """Encode one view.

        Returns:
            Dict with ``'z'`` (backbone output flattened from dim 1) and
            ``'h'`` (projection head applied to ``z``).
        """
        z = self.backbone(x).flatten(start_dim=1)
        h = self.projection_head(z)
        return {'z': z, 'h': h}

    def forward(self, batch):
        """Encode every view in ``batch``.

        Args:
            batch: Sequence of views; each element is moved to
                ``self.device`` and encoded independently.

        Returns:
            Dict with keys ``z0, z1, ...`` followed by ``h0, h1, ...``, where
            the index is the position of the view in ``batch``.
        """
        outs = [self._forward(view.to(self.device)) for view in batch]

        merged = {f'z{i}': out['z'] for i, out in enumerate(outs)}
        merged.update({f'h{i}': out['h'] for i, out in enumerate(outs)})
        return merged


class SimSiamEncoder(BaseEncoder):
    """Backbone with SimSiam projection and prediction heads.

    The projection head uses ``backbone.out_dim`` for its input, hidden and
    output sizes. The prediction head maps ``backbone.out_dim`` through
    ``config.agent.hidden_dim`` back to ``backbone.out_dim``.

    Args:
        config: Configuration object. Reads ``config.device`` and
            ``config.agent.hidden_dim``; it is also passed to
            ``load_backbone`` and ``configure_optimizers``.
    """

    def __init__(self, config):
        super().__init__()

        self.device = config.device

        self.backbone = load_backbone(config)
        feature_dim = self.backbone.out_dim
        self.projection_head = SimSiamProjectionHead(feature_dim,
                                                     feature_dim,
                                                     feature_dim)
        self.prediction_head = SimSiamPredictionHead(feature_dim,
                                                     config.agent.hidden_dim,
                                                     feature_dim)

        # Called after the layers exist so their parameters reach the optimizer.
        self.configure_optimizers(config)

        self.step = 0

    def embed(self, x):
        """Return the backbone output for ``x`` with gradients disabled.

        The input is moved to ``self.device`` (the same device ``forward``
        uses) instead of a hard-coded ``'cuda'``, so this also works on
        CPU-only hosts and with explicitly selected devices.
        """
        with torch.no_grad():
            return self.backbone(x.to(self.device))

    def _forward(self, x):
        """Encode one view.

        Returns:
            Dict with ``'z'`` (backbone output flattened from dim 1), ``'h'``
            (projection head applied to ``z``) and ``'p'`` (prediction head
            applied to ``h``).
        """
        z = self.backbone(x).flatten(start_dim=1)
        h = self.projection_head(z)
        p = self.prediction_head(h)

        return {'z': z, 'h': h, 'p': p}

    def forward(self, batch):
        """Encode every view in ``batch``.

        Args:
            batch: Sequence of views; each element is moved to
                ``self.device`` and encoded independently.

        Returns:
            Dict with keys ``z0, z1, ...``, then ``h0, h1, ...``, then
            ``p0, p1, ...``, where the index is the position of the view in
            ``batch``.
        """
        outs = [self._forward(view.to(self.device)) for view in batch]

        merged = {}
        # Grouped by output name first to keep the original key order.
        for name in ('z', 'h', 'p'):
            merged.update({f'{name}{i}': out[name] for i, out in enumerate(outs)})
        return merged
    


class VicRegEncoder(BaseEncoder):
    """Backbone with a two-layer VICReg projection head.

    Args:
        config: Configuration object. Reads ``config.device``,
            ``config.agent.hidden_dim`` and ``config.agent.proj_dim``; it is
            also passed to ``load_backbone`` and ``configure_optimizers``.
    """

    def __init__(self, config):
        super().__init__()

        self.device = config.device

        self.backbone = load_backbone(config)
        self.projection_head = VICRegProjectionHead(self.backbone.out_dim,
                                                    config.agent.hidden_dim,
                                                    config.agent.proj_dim,
                                                    num_layers=2)

        # Called after the layers exist so their parameters reach the optimizer.
        self.configure_optimizers(config)

        self.step = 0

    def embed(self, x):
        """Return the backbone output for ``x`` with gradients disabled.

        The input is moved to ``self.device`` (the same device ``forward``
        uses) instead of a hard-coded ``'cuda'``, so this also works on
        CPU-only hosts and with explicitly selected devices.
        """
        with torch.no_grad():
            return self.backbone(x.to(self.device))

    def _forward(self, x):
        """Encode one view.

        Returns:
            Dict with ``'z'`` (backbone output flattened from dim 1) and
            ``'h'`` (projection head applied to ``z``).
        """
        z = self.backbone(x).flatten(start_dim=1)
        h = self.projection_head(z)

        return {'z': z, 'h': h}

    def forward(self, batch):
        """Encode every view in ``batch``.

        Args:
            batch: Sequence of views; each element is moved to
                ``self.device`` and encoded independently.

        Returns:
            Dict with keys ``z0, z1, ...`` followed by ``h0, h1, ...``, where
            the index is the position of the view in ``batch``.
        """
        outs = [self._forward(view.to(self.device)) for view in batch]

        merged = {f'z{i}': out['z'] for i, out in enumerate(outs)}
        merged.update({f'h{i}': out['h'] for i, out in enumerate(outs)})
        return merged
    

def load_backbone(config):
    """Build the torchvision ResNet selected by ``config.agent.arch``.

    The model's ``fc`` layer is replaced with an identity module, and the
    width of the resulting feature vector is stored on the returned module as
    ``out_dim``.

    Args:
        config: Configuration object. Reads ``config.agent.arch`` and, for
            ``'resnet18'`` only, ``config.agent.pretrained`` and
            ``config.agent.pretrained_path``.

    Returns:
        The backbone module, carrying an extra ``out_dim`` attribute.

    Raises:
        ValueError: If ``config.agent.arch`` is not one of the supported
            architectures. Previously this case silently returned ``None``.
    """
    # Feature width of each supported architecture once ``fc`` is removed.
    out_dims = {'resnet18': 512, 'resnet34': 512, 'resnet50': 2048}

    arch = config.agent.arch
    if arch not in out_dims:
        raise ValueError(
            f'Unsupported backbone architecture {arch!r}; '
            f'expected one of {sorted(out_dims)}'
        )

    # The supported names match the torchvision constructor names.
    backbone = getattr(torchvision.models, arch)()
    backbone.fc = torch.nn.Identity()
    backbone.out_dim = out_dims[arch]

    # Pretrained weights are only wired up for resnet18; the other
    # architectures never consulted ``config.agent.pretrained``.
    if arch == 'resnet18' and config.agent.pretrained:
        # Load onto CPU so checkpoints saved from a GPU also load on hosts
        # without one; load_state_dict copies into the existing parameters.
        state_dict = torch.load(config.agent.pretrained_path, map_location='cpu')
        backbone.load_state_dict(state_dict)

    return backbone

def load_encoder(config):
    """Build the encoder matching ``config.agent.loss`` and move it to the device.

    Supported values of ``config.agent.loss``:

    * ``'simclr'`` and ``'scale'`` -> ``SimCLREncoder``
    * ``'simsiam'`` -> ``SimSiamEncoder``
    * ``'vicreg'`` -> ``VicRegEncoder``
    * ``'super'`` -> ``SupervisedEncoder``

    Args:
        config: Configuration object. Reads ``config.agent.loss`` and
            ``config.device``; it is also passed to the encoder constructor.

    Returns:
        The constructed encoder, moved to ``config.device``.

    Raises:
        ValueError: If ``config.agent.loss`` is not a supported value.
            Previously this surfaced as an ``UnboundLocalError``.
    """
    loss = config.agent.loss

    if loss in ('simclr', 'scale'):
        model = SimCLREncoder(config)
    elif loss == 'simsiam':
        model = SimSiamEncoder(config)
    elif loss == 'vicreg':
        model = VicRegEncoder(config)
    elif loss == 'super':
        model = SupervisedEncoder(config)
    else:
        raise ValueError(
            f'Unsupported loss {loss!r}; expected one of '
            "'simclr', 'simsiam', 'vicreg', 'super', 'scale'"
        )

    # NOTE: multi-GPU wrapping via nn.DataParallel is not applied here.

    model.to(config.device)

    return model