import os
import math
import copy
import torch
import random
import pickle
import torch.nn as nn
import torch.nn.functional as F

import utils.util as util
from typing import List

def L2_importance(w1, w2):
    """Return the squared L2 distance between matching rows of *w1* and *w2*.

    The element-wise squared difference is summed over ``dim=1``, so for 2-D
    inputs the result has one value per row.
    """
    return (w1 - w2).pow(2).sum(dim=1)

def cosine_importance(w1, w2):
    """Return the negated cosine similarity of *w1* and *w2* along ``dim=1``.

    Rows that point in the same direction score -1 and opposite rows score +1,
    so a larger value means the two rows are less alike.
    """
    similarity = torch.cosine_similarity(w1, w2, dim=1)
    return torch.neg(similarity)


class PSLayer():
    """Mixin holding the per-output-channel selection state of a PS layer.

    ``reset_mask`` writes into ``self.ps_mask``, which this class does not
    create itself; the concrete layer is expected to provide it.
    """

    def __init__(self, out_features: int):
        # Not yet merged/finalised; subclasses switch their forward pass on this.
        self.trained = False
        # Fraction of channels selected by the most recent reset_mask(p=...) call.
        self.ps_p = 0.2
        # Placeholder importance scores and channel ordering (0..out_features-1).
        self.imps = torch.arange(0, out_features)
        self.similarity_rank = torch.arange(0, out_features)

    def reset_mask(self, p=0.2, threshold=None):
        """Rewrite ``ps_mask`` in place as a 0/1 selection of output channels.

        With *threshold* given, a channel is selected when its entry in
        ``self.imps`` is ``>= threshold`` and *p* is ignored.  Otherwise the
        first ``ceil(len(ps_mask) * p)`` entries of ``self.similarity_rank``
        are selected and *p* is remembered in ``self.ps_p``.
        """
        if threshold is not None:
            selected = (self.imps >= threshold).float()
            self.ps_mask[:] = selected
            return

        assert 0 <= p <= 1
        self.ps_p = p
        num_selected = math.ceil(len(self.ps_mask) * p)
        picked = self.similarity_rank[:num_selected]
        self.ps_mask[:] = 0
        self.ps_mask[picked] = 1

    def compute_channel_importance(self, imp_fn):
        """Hook for subclasses; the base implementation does nothing."""

    def copy_params(self):
        """Hook for subclasses; the base implementation does nothing."""

class Linear(nn.Linear, PSLayer):
    """``nn.Linear`` with a second ("ps"/child) parameter set and a channel mask.

    ``weight``/``bias`` are the parent parameters and ``ps_weight``/``ps_bias``
    the child parameters.  The ``ps_mask`` buffer has one entry per output
    feature and blends the two outputs in ``forward``: 1 takes the child
    output, 0 the parent output.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        PSLayer.__init__(self, out_features)
        self.register_buffer("ps_mask", torch.zeros(out_features))
        # deepcopy of a Parameter is a Parameter, so these register as parameters;
        # ps_bias is a plain None attribute when the layer has no bias.
        self.ps_weight = copy.deepcopy(self.weight)
        self.ps_bias = copy.deepcopy(self.bias)
        # Second call on purpose: now that ps_* exist they get initialised too.
        self.reset_parameters()

    def forward(self, input):
        """Blend child and parent outputs by ``ps_mask`` (parent only once trained)."""
        if self.trained:
            return F.linear(input, self.weight, self.bias)
        parent_out = F.linear(input, self.weight, self.bias)
        child_out = F.linear(input, self.ps_weight, self.ps_bias)
        return child_out * self.ps_mask + parent_out * (1 - self.ps_mask)

    @torch.no_grad()
    def post_processing(self):
        """Merge the masked child parameters into ``weight``/``bias``.

        Mirrors ``Conv2d.post_processing``: marks the layer as trained, writes
        ``ps_* * mask + parent * (1 - mask)`` into the parent parameters and
        then drops ``ps_weight``, ``ps_bias`` and ``ps_mask``.  Calling it
        again after the merge is a no-op.
        """
        self.trained = True
        if self.ps_weight is None:
            # Already merged; nothing left to fold in.
            return
        mask = self.ps_mask
        row_mask = mask.unsqueeze(1)  # one gate per output row of the weight
        self.weight.data = self.ps_weight.data * row_mask + self.weight.data * (1 - row_mask)
        if self.bias is not None:
            self.bias.data = self.ps_bias.data * mask + self.bias.data * (1 - mask)
        self.ps_bias = self.ps_weight = self.ps_mask = None

    def reset_parameters(self):
        """Re-initialise the parent parameters and, when present, the child ones."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before ps_weight exists, and
        # post_processing sets it to None; both cases skip the child init.
        ps_weight = getattr(self, "ps_weight", None)
        if ps_weight is None:
            return
        nn.init.kaiming_uniform_(ps_weight, a=math.sqrt(5))
        if self.ps_bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(ps_weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.ps_bias, -bound, bound)

    def copy_params(self):
        """Overwrite the child parameters with deep copies of the parent ones."""
        self.ps_weight = copy.deepcopy(self.weight)
        if self.bias is not None:
            self.ps_bias = copy.deepcopy(self.bias)
        else:
            self.ps_bias = None

    def compute_channel_importance(self, imp_fn=cosine_importance):
        """Score each output row with ``imp_fn(weight, ps_weight)``.

        Stores the scores in ``self.imps`` and their descending sort order in
        ``self.similarity_rank``; returns the scores.
        """
        self.imps = imp_fn(self.weight, self.ps_weight)
        _, self.similarity_rank = torch.sort(self.imps, descending=True)
        return self.imps

    def set_trainable_params(self):
        """Freeze every parameter, then re-enable gradients for the child ones."""
        for p in self.parameters():
            p.requires_grad = False
        self.ps_weight.requires_grad = True
        if self.ps_bias is not None:
            self.ps_bias.requires_grad = True
        
class Conv2d(nn.Conv2d, PSLayer):
    """``nn.Conv2d`` with a second ("ps"/child) parameter set and a channel mask.

    ``weight``/``bias`` are the parent parameters and ``ps_weight``/``ps_bias``
    the child parameters.  The ``ps_mask`` buffer has one entry per output
    channel and blends the two outputs in ``forward``: 1 takes the child
    output, 0 the parent output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        PSLayer.__init__(self, out_channels)
        self.register_buffer("ps_mask", torch.zeros(out_channels))
        # deepcopy of a Parameter is a Parameter, so these register as parameters;
        # ps_bias is a plain None attribute when the layer has no bias.
        self.ps_weight = copy.deepcopy(self.weight)
        self.ps_bias = copy.deepcopy(self.bias)
        # Second call on purpose: now that ps_* exist they get initialised too.
        self.reset_parameters()

    def forward(self, input):
        """Blend child and parent outputs per channel (parent only once trained)."""
        if self.trained:
            return self._conv_forward(input, self.weight, self.bias)
        # (C, 1, 1) broadcasts over the channel axis of (N, C, H, W) and (C, H, W).
        mask = self.ps_mask.view(-1, 1, 1)
        parent_out = self._conv_forward(input, self.weight, self.bias)
        child_out = self._conv_forward(input, self.ps_weight, self.ps_bias)
        return child_out * mask + parent_out * (1 - mask)

    @torch.no_grad()
    def post_processing(self):
        """Merge the masked child parameters into ``weight``/``bias``.

        Marks the layer as trained, writes ``ps_* * mask + parent * (1 - mask)``
        into the parent parameters and then drops ``ps_weight``, ``ps_bias``
        and ``ps_mask``.  Calling it again after the merge is a no-op.
        """
        self.trained = True
        if self.ps_weight is None:
            # Already merged; nothing left to fold in.
            return
        mask = self.ps_mask
        # Broadcasting over the output-channel axis keeps the merged weight
        # contiguous, unlike a permute round-trip.
        filter_mask = mask.view(-1, 1, 1, 1)
        self.weight.data = self.ps_weight.data * filter_mask + self.weight.data * (1 - filter_mask)
        if self.bias is not None:
            self.bias.data = self.ps_bias.data * mask + self.bias.data * (1 - mask)
        self.ps_bias = self.ps_weight = self.ps_mask = None

    def reset_parameters(self):
        """Re-initialise the parent parameters and, when present, the child ones."""
        nn.Conv2d.reset_parameters(self)
        # nn.Conv2d.__init__ calls this before ps_weight exists, and
        # post_processing sets it to None; both cases skip the child init.
        ps_weight = getattr(self, "ps_weight", None)
        if ps_weight is None:
            return
        nn.init.kaiming_uniform_(ps_weight, a=math.sqrt(5))
        if self.ps_bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(ps_weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.ps_bias, -bound, bound)

    def copy_params(self):
        """Overwrite the child parameters with deep copies of the parent ones."""
        self.ps_weight = copy.deepcopy(self.weight)
        if self.bias is not None:
            self.ps_bias = copy.deepcopy(self.bias)
        else:
            self.ps_bias = None

    def compute_channel_importance(self, imp_fn=cosine_importance):
        """Score each output channel with ``imp_fn`` on the flattened filters.

        Stores the scores in ``self.imps`` and their descending sort order in
        ``self.similarity_rank``; returns the scores.
        """
        parent_flat = self.weight.reshape(self.out_channels, -1)
        child_flat = self.ps_weight.reshape(self.out_channels, -1)
        self.imps = imp_fn(parent_flat, child_flat)
        _, self.similarity_rank = torch.sort(self.imps, descending=True)
        return self.imps

    def set_trainable_params(self):
        """Freeze every parameter, then re-enable gradients for the child ones."""
        for p in self.parameters():
            p.requires_grad = False
        self.ps_weight.requires_grad = True
        if self.ps_bias is not None:
            self.ps_bias.requires_grad = True

# Cut-off applied to sigmoid(ps_mask) in TConv2d / TLinear. When the mask is
# hardened, output channels whose gate is above this value use the ps_* (child)
# parameters and the remaining channels use the parent weight/bias.
SHARE_THRESHOLD = 0.3

class TConv2d(nn.Conv2d):
    """``nn.Conv2d`` with child parameters and a learnable per-channel gate.

    ``ps_mask`` is a parameter with one logit per output channel.  In
    ``forward`` its sigmoid blends the child (``ps_weight``/``ps_bias``) and
    parent (``weight``/``bias``) outputs; after ``prune_layer`` the gate is
    hardened to 0/1 with ``SHARE_THRESHOLD``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        self.ps_mask = nn.Parameter(torch.zeros(out_channels))
        self.ps_weight = copy.deepcopy(self.weight)
        self.ps_bias = copy.deepcopy(self.bias)
        # Second call on purpose: now that ps_* exist they get initialised too.
        self.reset_parameters()
        self.soft_mask = True
        nn.init.normal_(self.ps_mask)
        self.trained = False

    def forward(self, input):
        """Blend child and parent outputs per channel (parent only once trained)."""
        if self.trained:
            return self._conv_forward(input, self.weight, self.bias)
        gate = torch.sigmoid(self.ps_mask)
        if not self.soft_mask:
            gate = (gate > SHARE_THRESHOLD).float()
        # (C, 1, 1) broadcasts over the channel axis of (N, C, H, W) and (C, H, W).
        gate = gate.view(-1, 1, 1)
        parent_out = self._conv_forward(input, self.weight, self.bias)
        child_out = self._conv_forward(input, self.ps_weight, self.ps_bias)
        return child_out * gate + parent_out * (1 - gate)

    @torch.no_grad()
    def post_processing(self):
        """Merge the hard-gated child parameters into ``weight``/``bias``.

        Channels with ``sigmoid(ps_mask) > SHARE_THRESHOLD`` take the child
        values, the others keep the parent values.  Afterwards ``ps_weight``,
        ``ps_bias`` and ``ps_mask`` are set to None.  Calling it again after
        the merge is a no-op.
        """
        self.trained = True
        if self.ps_weight is None:
            # Already merged; nothing left to fold in.
            return
        mask = (torch.sigmoid(self.ps_mask) > SHARE_THRESHOLD).float()
        # Broadcasting over the output-channel axis keeps the merged weight
        # contiguous, unlike a permute round-trip.
        filter_mask = mask.view(-1, 1, 1, 1)
        self.weight.data = self.ps_weight.data * filter_mask + self.weight.data * (1 - filter_mask)
        if self.bias is not None:
            self.bias.data = self.ps_bias.data * mask + self.bias.data * (1 - mask)
        self.ps_bias = self.ps_weight = self.ps_mask = None

    def reset_parameters(self):
        """Re-initialise the parent parameters and, when present, the child ones."""
        nn.Conv2d.reset_parameters(self)
        # nn.Conv2d.__init__ calls this before ps_weight exists, and
        # post_processing sets it to None; both cases skip the child init.
        ps_weight = getattr(self, "ps_weight", None)
        if ps_weight is None:
            return
        nn.init.kaiming_uniform_(ps_weight, a=math.sqrt(5))
        if self.ps_bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(ps_weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.ps_bias, -bound, bound)

    def copy_params(self):
        """Overwrite the child parameters with deep copies of the parent ones."""
        self.ps_weight = copy.deepcopy(self.weight)
        if self.bias is not None:
            self.ps_bias = copy.deepcopy(self.bias)
        else:
            self.ps_bias = None

    def prune_layer(self):
        """Switch ``forward`` from the soft sigmoid gate to the hard 0/1 gate."""
        self.soft_mask = False

    def get_sparse_ratio(self):
        """Return ``(count, penalty)`` computed from ``sigmoid(ps_mask)``.

        ``count`` is the number of gates below ``SHARE_THRESHOLD`` and
        ``penalty`` the sum of squared distances of the gates from it.
        """
        mask = torch.sigmoid(self.ps_mask)
        return (mask < SHARE_THRESHOLD).sum(), ((mask - SHARE_THRESHOLD) ** 2).sum()

    def reset_mask(self, p=0.2, threshold=None):
        """No-op: the gate is learned, so both arguments are ignored.

        The signature matches ``PSLayer.reset_mask`` so the two layer families
        can be called the same way.
        """
        return

    def set_trainable_params(self):
        """Freeze every parameter, then re-enable the child ones and the gate."""
        for p in self.parameters():
            p.requires_grad = False
        self.ps_weight.requires_grad = True
        self.ps_mask.requires_grad = True
        if self.ps_bias is not None:
            self.ps_bias.requires_grad = True
      
class TLinear(nn.Linear):
    """``nn.Linear`` with child parameters and a learnable per-feature gate.

    ``ps_mask`` is a parameter with one logit per output feature.  In
    ``forward`` its sigmoid blends the child (``ps_weight``/``ps_bias``) and
    parent (``weight``/``bias``) outputs; after ``prune_layer`` the gate is
    hardened to 0/1 with ``SHARE_THRESHOLD``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        nn.Linear.__init__(self, in_channels, out_channels, **kwargs)
        self.ps_mask = nn.Parameter(torch.zeros(out_channels))
        self.ps_weight = copy.deepcopy(self.weight)
        self.ps_bias = copy.deepcopy(self.bias)
        # Re-draws the parent parameters only (see reset_parameters), so the
        # child copies made above end up different from them.
        self.reset_parameters()
        self.soft_mask = True
        nn.init.normal_(self.ps_mask)
        self.trained = False

    def forward(self, input):
        """Blend child and parent outputs per feature (parent only once trained)."""
        if self.trained:
            return F.linear(input, self.weight, self.bias)
        gate = torch.sigmoid(self.ps_mask)
        if not self.soft_mask:
            gate = (gate > SHARE_THRESHOLD).float()
        parent_out = F.linear(input, self.weight, self.bias)
        child_out = F.linear(input, self.ps_weight, self.ps_bias)
        return child_out * gate + parent_out * (1 - gate)

    @torch.no_grad()
    def post_processing(self):
        """Merge the hard-gated child parameters into ``weight``/``bias``.

        Features with ``sigmoid(ps_mask) > SHARE_THRESHOLD`` take the child
        values, the others keep the parent values.  Afterwards ``ps_weight``,
        ``ps_bias`` and ``ps_mask`` are set to None.  Calling it again after
        the merge is a no-op.
        """
        self.trained = True
        if self.ps_weight is None:
            # Already merged; nothing left to fold in.
            return
        mask = (torch.sigmoid(self.ps_mask) > SHARE_THRESHOLD).float()
        # One gate per output row; broadcasting keeps the merged weight
        # contiguous, unlike a permute round-trip.
        row_mask = mask.unsqueeze(1)
        self.weight.data = self.ps_weight.data * row_mask + self.weight.data * (1 - row_mask)
        if self.bias is not None:
            self.bias.data = self.ps_bias.data * mask + self.bias.data * (1 - mask)
        self.ps_bias = self.ps_weight = self.ps_mask = None

    def reset_parameters(self):
        """Re-initialise the parent parameters; the child ones are left untouched."""
        nn.Linear.reset_parameters(self)

    def copy_params(self):
        """Overwrite the child parameters with deep copies of the parent ones."""
        self.ps_weight = copy.deepcopy(self.weight)
        if self.bias is not None:
            self.ps_bias = copy.deepcopy(self.bias)
        else:
            self.ps_bias = None

    def prune_layer(self):
        """Switch ``forward`` from the soft sigmoid gate to the hard 0/1 gate."""
        self.soft_mask = False

    def get_sparse_ratio(self):
        """Return ``(count, penalty)`` computed from ``sigmoid(ps_mask)``.

        ``count`` is the number of gates below ``SHARE_THRESHOLD`` and
        ``penalty`` the sum of squared distances of the gates from 0.5.
        """
        mask = torch.sigmoid(self.ps_mask)
        # NOTE(review): TConv2d centres this penalty on SHARE_THRESHOLD rather
        # than 0.5 -- confirm which of the two is intended.
        return (mask < SHARE_THRESHOLD).sum(), ((mask - 0.5) ** 2).sum()

    def reset_mask(self, p=0.2, threshold=None):
        """No-op: the gate is learned, so both arguments are ignored.

        The signature matches ``PSLayer.reset_mask`` so the two layer families
        can be called the same way.
        """
        return

    def set_trainable_params(self):
        """Freeze every parameter, then re-enable the child ones and the gate."""
        for p in self.parameters():
            p.requires_grad = False
        self.ps_weight.requires_grad = True
        self.ps_mask.requires_grad = True
        if self.ps_bias is not None:
            self.ps_bias.requires_grad = True
      
class ModelUnion():
    """Registry of one root model plus one slot per downstream task.

    Index 0 always refers to the root model (recorded with 1000 classes and the
    task name "ImageNet"); indices 1.. refer to the tasks in the order given
    to the constructor.  A task slot in ``models`` stays None until a model is
    added or loaded for it.
    """

    def __init__(self, root_model, train_sets, val_sets, num_classes, task_names):
        # The per-task arguments are concatenated onto single-element lists, so
        # they must be lists themselves.
        num_tasks = len(train_sets)
        self.models: List[nn.Module] = [root_model] + [None] * num_tasks
        self.train_sets = [None] + train_sets
        self.val_sets = [None] + val_sets
        self.num_classes = [1000] + num_classes
        self.accs = [0] * (num_tasks + 1)
        self.task_names = ["ImageNet"] + task_names

    def get_task_context(self, args, idx, prefix="fc"):
        """Return a fresh model plus the datasets and name of task *idx*.

        *args* is accepted but not used.  *idx* must refer to a task slot
        (``1 <= idx < len(self.models)``).
        """
        assert 1 <= idx < len(self.models)
        model = self.get_average_pretained_model(idx, self.num_classes[idx], prefix=prefix)
        return {
            "model": model,
            "train_set": self.train_sets[idx],
            "val_set": self.val_sets[idx],
            "task": self.task_names[idx]
        }

    @torch.no_grad()
    def get_average_pretained_model(self, idx, num_classes, average=False, prefix="fc"):
        """Copy the root model and give it a new ``nn.Linear`` head at *prefix*.

        The new head has ``num_classes`` outputs and the ``in_features`` of the
        layer it replaces.  With ``average=True`` every other state-dict entry
        is replaced by the mean over the root copy and the stored task models
        other than *idx*.  All parameters of the returned model require grad.
        """
        new_model = copy.deepcopy(self.models[0])
        out_layer = new_model.get_submodule(prefix)
        parent_path, _, out_layer_name = prefix.rpartition(".")
        father_module = new_model.get_submodule(parent_path) if parent_path else new_model
        setattr(father_module, out_layer_name, nn.Linear(
            out_layer.in_features,
            num_classes
        ))
        if average:
            # [1:] drops the root model: new_model already starts as its copy.
            trained_models = [m for i, m in enumerate(self.models) if m is not None and i != idx][1:]
            stat_dicts = [m.state_dict() for m in trained_models]
            new_model_state_dict = new_model.state_dict()
            head_keys = (prefix + '.weight', prefix + '.bias')
            for k in new_model_state_dict:
                if k in head_keys:
                    # Head shapes depend on the task, so they are never averaged.
                    continue
                for trained_stat in stat_dicts:
                    new_model_state_dict[k] += trained_stat[k]
                new_model_state_dict[k] = new_model_state_dict[k] / (len(trained_models) + 1)
            new_model.load_state_dict(new_model_state_dict)
        for p in new_model.parameters():
            p.requires_grad = True
        return new_model

    @torch.no_grad()
    def convert_ps_model(self, model: nn.Module, task_idx, prefix="fc"):
        """Turn *model* into a PS model whose parent weights come from stored models.

        After ``util.convert_ps`` / ``util.ps_model_init`` (project helpers;
        presumably they swap layers for their PS variants -- verify there),
        every PS layer gets, per output channel, the parent weights of the
        stored model (other than *task_idx*) whose channel has the smallest
        ``L2_importance`` distance to the layer's ``ps_weight``.  The distances
        are then centred, divided by their variance, doubled and written to
        ``ps_mask``.

        Returns one text line per layer listing, for every channel, the index
        of the stored model its parent weights were taken from.
        """
        util.convert_ps(model, ignore_layers=[model.get_submodule(prefix)], mask_trainable=True)
        util.ps_model_init(model, p=1)
        sources = [(i, m) for i, m in enumerate(self.models) if m is not None and i != task_idx]
        # The first available model (normally the root) is the starting point.
        _, base_model = sources[0]
        imp_fn = L2_importance
        visual_data = []
        for n, module in model.named_modules():
            if not isinstance(module, (Conv2d, Linear, TConv2d, TLinear)):
                continue
            base_layer = base_model.get_submodule(n)
            module.weight = copy.deepcopy(base_layer.weight)
            module.bias = copy.deepcopy(base_layer.bias)
            c = module.weight.shape[0]
            # Shape that broadcasts a per-channel vector over the weight tensor.
            per_channel = (c,) + (1,) * (module.weight.dim() - 1)
            # Allocated next to the weights so the arithmetic below also works
            # when the model lives on an accelerator.
            c_from = torch.zeros(c, device=module.weight.device)
            target = module.ps_weight.reshape(c, -1)
            diff = imp_fn(module.weight.reshape(c, -1), target)
            for src_idx, src_model in sources[1:]:
                trained_layer = src_model.get_submodule(n)
                new_diff = imp_fn(trained_layer.weight.reshape(c, -1), target)
                # 1.0 where this source is strictly closer than the best so far.
                channel_filter = (new_diff < diff).float()
                # Record which stored model each channel's weights come from.
                c_from = c_from * (1 - channel_filter) + float(src_idx) * channel_filter
                diff = channel_filter * new_diff + (1 - channel_filter) * diff
                weight_filter = channel_filter.view(per_channel)
                module.weight.data = (weight_filter * trained_layer.weight.data +
                                      (1 - weight_filter) * module.weight.data)
                if module.bias is not None:
                    module.bias.data = channel_filter * trained_layer.bias.data + (1 - channel_filter) * module.bias.data
            # NOTE(review): this divides by the variance, not the standard
            # deviation, and yields NaN/inf when all distances are equal or the
            # layer has a single output channel -- confirm this is intended.
            module.ps_mask.data = ((diff - diff.mean()) / diff.var()) * 2.0
            visual_data.append(f"{n}: {c_from.cpu()}\n")
        # TODO: figure out whether need to copy params
        norm_layers = [m for m in model.modules() if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d))]
        util.mark_only_ps_as_trainable(model, ignore_layers=[model.get_submodule(prefix)] + norm_layers)
        return "".join(visual_data)

    def add_model(self, model, task_idx, acc):
        """Store *model* for *task_idx* if *acc* is at least the recorded accuracy.

        When the model is accepted, ``post_processing`` is called on each of
        its PS layers.
        """
        assert 1 <= task_idx < len(self.models)
        if self.accs[task_idx] > acc:
            return
        self.accs[task_idx] = acc
        self.models[task_idx] = model
        for module in model.modules():
            if isinstance(module, (Conv2d, Linear, TConv2d, TLinear)):
                module.post_processing()

    def save_models(self, path):
        """Write every stored model to ``{path}/{idx}.pth`` plus an index file.

        State-dict entries whose key contains ``ps_weight``, ``ps_bias`` or
        ``ps_mask`` are left out.  ``{path}/model_infos.pkl`` holds the
        accuracies and the list of checkpoint paths.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(path, exist_ok=True)
        ps_markers = ("ps_weight", "ps_bias", "ps_mask")
        saved_paths = []
        for idx, m in enumerate(self.models):
            if m is None:
                continue
            new_dict = {
                k: v for k, v in m.state_dict().items()
                if not any(marker in k for marker in ps_markers)
            }
            torch.save(new_dict, f"{path}/{idx}.pth")
            saved_paths.append(f"{path}/{idx}.pth")
        data = {
            "accs": self.accs,
            "models": saved_paths
        }
        with open(f"{path}/model_infos.pkl", "wb") as f:
            pickle.dump(data, f)

    def load_models(self, path, prefix):
        """Fill empty task slots from the checkpoints stored under *path*.

        Replaces ``self.accs`` with the saved accuracies.  Checkpoints are
        matched by file name (``{idx}.pth``), so the directory may be given
        with a different spelling, or have been moved, since it was saved.
        """
        # NOTE: pickle.load executes arbitrary code; only use trusted directories.
        with open(f"{path}/model_infos.pkl", "rb") as f:
            data = pickle.load(f)
        print(f"loading models from {path}/model_infos.pkl")
        self.accs = data["accs"]
        print(f"{[(t, d) for t, d in zip(self.task_names, self.accs)]}")
        saved_files = {os.path.basename(p) for p in data["models"]}
        for idx, m in enumerate(self.models):
            if m is not None or f"{idx}.pth" not in saved_files:
                continue
            model = self.get_average_pretained_model(idx, self.num_classes[idx], average=False, prefix=prefix)
            # Load to CPU so checkpoints written on a GPU machine can be read
            # anywhere; load_state_dict copies onto the model's own device.
            stat_dict = torch.load(f'{path}/{idx}.pth', map_location="cpu")
            model.load_state_dict(stat_dict)
            self.models[idx] = model