import cv2
import copy
import logging
import numpy as np
import torch
from torch import nn
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
import torch.nn.functional as F
from itertools import chain

from inclearn.lib import factory
from inclearn.lib.network.mlp import Fusion, MLP, MaxOut_MLP

logger = logging.getLogger(__name__)


class BasicNet(nn.Module):
    """Multimodal classifier: per-modality encoders -> fusion MLP -> linear head.

    The encoders returned by ``_get_encoders`` are kept in a plain ``dict``
    (not an ``nn.ModuleDict``), so they are NOT registered submodules:
    ``self.parameters()``, ``self.state_dict()`` and ``self.to()`` do not see
    them. That is why ``train``/``eval``, ``freeze``, ``get_group_parameters``,
    ``save`` and ``load`` all handle ``self.encoders`` explicitly.
    """

    def __init__(
        self,
        dataset_name,
        n_classes,
        arch=None,
        device=None
    ):
        """Build the encoders for ``dataset_name``, the fusion MLP and the classifier.

        Args:
            dataset_name: dataset key understood by ``_get_encoders``.
            n_classes: number of output logits of the classifier.
            arch: forwarded unchanged to ``_get_encoders``.
            device: device the fusion MLP and classifier are moved to; also
                forwarded to ``_get_encoders``.
        """
        super(BasicNet, self).__init__()
        self.device = device

        # When True, each modality's features are L2-normalized (dim=1)
        # before being concatenated for the fusion MLP.
        self.normalize_before_fusion = True

        self.encoders, self.encoders_freeze, total_dims = _get_encoders(dataset_name, self.device, arch)
        self.modalities = list(self.encoders.keys())
        self.fusion = MLP(input_dim=total_dims, hidden_dims=[2048, 1024], input_dropout=0.3).to(self.device)

        # 1024 is presumably the fusion MLP's output width (its last hidden
        # dim); keep the two in sync.
        self.classifier = nn.Linear(1024, n_classes).to(self.device)

        # Only moves registered submodules (fusion, classifier); the encoders
        # are placed on the device inside _get_encoders.
        self.to(self.device)

    def train(self, mode=True):
        """Put fusion, classifier and every encoder in train (or eval) mode.

        Mirrors ``nn.Module.train``: accepts ``mode`` and returns ``self``, so
        ``net.train(False)`` and parent modules propagating ``train(mode)``
        work. The encoders must be switched explicitly because they are not
        registered submodules.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self
        """
        self.training = mode
        for module in (self.fusion, self.classifier, *self.encoders.values()):
            if mode:
                module.train()
            else:
                module.eval()
        return self

    def eval(self):
        """Put every component in eval mode; equivalent to ``self.train(False)``.

        Returns:
            self
        """
        return self.train(False)

    @property
    def features_dim(self):
        """Width of the fused feature vector fed to the classifier."""
        return self.classifier.in_features

    def on_task_end(self):
        """Hook called at the end of a task; no-op for this network."""
        pass

    def on_epoch_end(self):
        """Hook called at the end of an epoch; no-op for this network."""
        pass

    def _encode_and_fuse(self, inputs):
        """Encode every modality and fuse the (optionally normalized) features.

        Args:
            inputs: mapping from modality name to that modality's input batch.

        Returns:
            dict with ``features_<modality>`` (each encoder's output; a 3-D
            output has dimension 2 squeezed) and ``features_fused``.
        """
        outputs = {}
        features = []
        for modality in self.modalities:
            feats = self.encoders[modality](inputs[modality])
            if feats.dim() == 3:
                # squeeze(dim=2) only removes that dimension if it has size 1.
                feats = feats.squeeze(dim=2)
            outputs[f"features_{modality}"] = feats

            if self.normalize_before_fusion:
                features.append(F.normalize(feats, dim=1))
            else:
                features.append(feats)

        outputs["features_fused"] = self.fusion(torch.cat(features, dim=1))
        return outputs

    def forward(self, inputs):
        """Run encoders, fusion and classifier.

        Args:
            inputs: mapping from modality name to that modality's input batch.

        Returns:
            dict with ``features_<modality>``, ``features_fused`` and ``logits``.
        """
        outputs = self._encode_and_fuse(inputs)
        outputs["logits"] = self.classifier(outputs["features_fused"])
        return outputs

    def extract(self, inputs):
        """Like ``forward`` but stops before the classifier (no ``logits`` key)."""
        return self._encode_and_fuse(inputs)

    def freeze(self, trainable=False, model="all"):
        """Set ``requires_grad`` and train/eval mode on selected components.

        Parameters whose names appear in ``self.encoders_freeze[<modality>]``
        are never touched, so layers frozen at construction stay frozen.

        Args:
            trainable: value assigned to ``requires_grad``; components are put
                in train mode when truthy and eval mode otherwise.
            model: a name or a list of names among "all", "encoders",
                "fusion", "classifier", or a modality key.

        Returns:
            self

        Raises:
            ValueError: if a name is not a known component.
        """
        if isinstance(model, str):
            model = [model]

        modules = {}
        for name in model:
            if name == "all":
                # self.named_parameters() would miss the encoders (they are not
                # registered submodules), so list every component explicitly.
                modules["fusion"] = self.fusion
                modules["classifier"] = self.classifier
                modules.update(self.encoders)
            elif name == "encoders":
                modules.update(self.encoders)
            elif name == "fusion":
                modules["fusion"] = self.fusion
            elif name == "classifier":
                modules["classifier"] = self.classifier
            elif name in self.modalities:
                modules[name] = self.encoders[name]
            else:
                raise ValueError(f"Unknown component to freeze: {name!r}")

        for key, module in modules.items():
            if not isinstance(module, nn.Module):
                continue

            pinned = set(self.encoders_freeze.get(key, ()))
            for param_name, param in module.named_parameters():
                if param_name in pinned:
                    continue
                param.requires_grad = trainable

            if trainable:
                module.train()
            else:
                module.eval()

        if "all" in model:
            self.training = bool(trainable)
        return self

    def get_group_parameters(self):
        """Return per-component parameter lists, e.g. for optimizer groups.

        Returns:
            dict with keys "classifier", "fusion" and "encoder_<modality>".
            Each value is a list, so it can be iterated more than once.
            Encoder groups exclude the parameter names listed in
            ``self.encoders_freeze``.
        """
        groups = {
            "classifier": list(self.classifier.parameters()),
            "fusion": list(self.fusion.parameters()),
        }

        for key, encoder in self.encoders.items():
            pinned = set(self.encoders_freeze[key])
            groups[f"encoder_{key}"] = [
                p for name, p in encoder.named_parameters() if name not in pinned
            ]

        return groups

    def copy(self):
        """Return a deep copy of the whole network (encoders included)."""
        return copy.deepcopy(self)

    def save(self, path):
        """Save fusion, classifier and every encoder state dict to ``path``."""
        save_states = {
            'fusion': self.fusion.state_dict(),
            'classifier': self.classifier.state_dict(),
        }
        for key, item in self.encoders.items():
            save_states[f"encoder_{key}"] = item.state_dict()

        torch.save(save_states, path)

    def load(self, path):
        """Load state dicts written by ``save`` from ``path``.

        Tensors are mapped to ``self.device`` while loading, so a checkpoint
        written on another device (e.g. a GPU) can be restored here.
        """
        save_states = torch.load(path, map_location=self.device)

        self.fusion.load_state_dict(save_states['fusion'])
        self.classifier.load_state_dict(save_states['classifier'])

        for key, item in self.encoders.items():
            item.load_state_dict(save_states[f"encoder_{key}"])



def _freeze_all_but(named_parameters, trainable_tags, prefix=""):
    """Freeze every parameter whose name contains none of ``trainable_tags``.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g. from
            ``nn.Module.named_parameters()``.
        trainable_tags: substrings; a parameter whose name contains any of
            them is left untouched (its ``requires_grad`` is not modified).
        prefix: string prepended to each recorded name.

    Returns:
        List of (prefixed) names of the parameters that were frozen.
    """
    frozen = []
    for name, param in named_parameters:
        if any(tag in name for tag in trainable_tags):
            continue
        param.requires_grad = False
        frozen.append(prefix + name)
    return frozen


def _load_slowfast_visual(device):
    """Build the SlowFast visual encoder with every parameter outside ``layer4`` frozen.

    Returns:
        ``(encoder, frozen_names)``.
    """
    logger.info("Loading visual encoder")
    visual_encoder = factory.get_encoder('slowfast', device=device)
    # Parameters are read from the wrapped ``.model``; recorded names get a
    # "model." prefix, presumably so they match the names the wrapper itself
    # reports from named_parameters() (which BasicNet.freeze iterates).
    visual_freeze = _freeze_all_but(
        visual_encoder.model.named_parameters(), ("layer4",), prefix="model."
    )
    return visual_encoder, visual_freeze


def _get_encoders(dataset_name, device, arch=None):
    """Build the per-modality encoders for ``dataset_name``.

    Args:
        dataset_name: "kitchen", "ave", "uestc_mmea" or "dkd"
            (case-insensitive, surrounding whitespace ignored).
        device: device passed to the encoder factory / ``.to()``.
        arch: currently unused; kept for interface compatibility.

    Returns:
        ``(encoders, encoders_freeze, total_dims)``. ``encoders`` maps modality
        name to encoder module. ``encoders_freeze`` maps modality name to the
        list of parameter names frozen here. ``total_dims`` is the sum of the
        encoders' ``out_dim``.

    Raises:
        NotImplementedError: for an unknown dataset name.
    """
    dataset_name = dataset_name.lower().strip()

    if dataset_name in ("kitchen", "ave"):
        visual_encoder, visual_freeze = _load_slowfast_visual(device)

        logger.info("Loading audio encoder")
        audio_encoder = factory.get_encoder('resnet18', in_ch=1)
        # Load on CPU first (as the dkd branch does) so a GPU-saved checkpoint
        # also loads on CPU-only machines; the encoder is moved to `device` below.
        audio_checkpoint = torch.load('pretrained/vggsound_avgpool.pth.tar', map_location='cpu')
        # Drop 'fc' (classification head) entries and strip the 'audnet.' prefix.
        audio_state_dict = {
            key.replace('audnet.', ''): item
            for key, item in audio_checkpoint['model_state_dict'].items()
            if 'fc' not in key
        }
        audio_encoder.load_state_dict(audio_state_dict)
        audio_freeze = _freeze_all_but(audio_encoder.named_parameters(), ("layer4",))

        total_dims = visual_encoder.out_dim + audio_encoder.out_dim
        encoders = {'visual': visual_encoder, 'audio': audio_encoder.to(device)}
        return encoders, {'visual': visual_freeze, 'audio': audio_freeze}, total_dims

    if dataset_name == "uestc_mmea":
        visual_encoder, visual_freeze = _load_slowfast_visual(device)

        logger.info("Loading inertial encoder")
        # Fetched through torch.hub (needs network access or a populated hub
        # cache); only its feature extractor is kept.
        inertial_encoder = torch.hub.load(
            'OxWearables/ssl-wearables', 'harnet10', class_num=5, pretrained=True
        ).feature_extractor
        # NOTE(review): hard-coded output width of the feature extractor — confirm.
        inertial_encoder.out_dim = 1024
        inertial_freeze = _freeze_all_but(inertial_encoder.named_parameters(), ("layer4", "layer5"))

        total_dims = visual_encoder.out_dim + inertial_encoder.out_dim
        encoders = {'visual': visual_encoder, 'inertial': inertial_encoder.to(device)}
        return encoders, {'visual': visual_freeze, 'inertial': inertial_freeze}, total_dims

    if dataset_name == "dkd":
        logger.info("Loading visual encoder")
        visual_encoder = factory.get_encoder('vit16', img_size=224, num_classes=5, drop_path_rate=0, global_pool=True)
        visual_checkpoint = torch.load('pretrained/RETFound_cfp_weights.pth', map_location='cpu')
        load_result = visual_encoder.load_state_dict(visual_checkpoint['model'], strict=False)
        # strict=False silently tolerates missing/unexpected keys; log the
        # outcome so a mismatched checkpoint does not go unnoticed.
        logger.info("RETFound checkpoint load result: %s", load_result)

        visual_freeze = []
        for k, p in visual_encoder.named_parameters():
            if 'fc_norm.' in k:
                p.requires_grad = True
            elif 'blocks.' in k and int(k.split('.')[1]) >= 22:
                # Transformer blocks with index >= 22 stay trainable
                # (presumably the last blocks of the ViT — confirm depth).
                p.requires_grad = True
            else:
                p.requires_grad = False
                visual_freeze.append(k)

        logger.info("Loading tabular encoder")
        tabular_encoder = MLP(input_dim=7, hidden_dims=[16, 32, 16])
        # Presumably the MLP's output width (its last hidden dim).
        tabular_encoder.out_dim = 16

        total_dims = visual_encoder.out_dim + tabular_encoder.out_dim
        encoders = {'visual': visual_encoder.to(device), 'tabular': tabular_encoder.to(device)}
        return encoders, {'visual': visual_freeze, 'tabular': []}, total_dims

    raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
