from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

from models.cif import supcon
from models.resnet18_encoder import *
from models.resnet20_cifar import *
from models.resnet12 import *
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

from scheduler.lr_scheduler import LinearWarmupCosineAnnealingLR


# from models.dualnet18_encoder import *
class projection_MLP(nn.Module):
    """MLP projection head: Linear-BN-ReLU block(s) followed by a non-affine BN output.

    The hidden width equals ``out_dim``.

    Args:
        in_dim: size of the input features.
        out_dim: size of the hidden and output features.
        num_layers: 2 (``layer1 -> layer3``) or 3 (``layer1 -> layer2 -> layer3``).

    Raises:
        ValueError: if ``num_layers`` is not 2 or 3. Any other value used to make
            ``forward`` silently return its input unchanged.
    """

    def __init__(self, in_dim, out_dim, num_layers=2):
        super().__init__()
        if num_layers not in (2, 3):
            raise ValueError(f"num_layers must be 2 or 3, got {num_layers}")
        hidden_dim = out_dim

        self.num_layers = num_layers

        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        # Only applied when num_layers == 3, but always created so that the
        # parameter layout does not depend on num_layers.
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim, affine=False)  # Page:5, Paragraph:2
        )

    def forward(self, x):
        """Project ``x`` through the configured number of layers."""
        x = self.layer1(x)
        if self.num_layers == 3:
            x = self.layer2(x)
        return self.layer3(x)


class MYNET(nn.Module):

    def __init__(self, args, mode=None):
        super().__init__()
        self.m = 0.35
        self.m_arg = 0.5

        self.mode = mode
        self.args = args

        if self.args.dataset in ['cifar100', 'manyshotcifar']:
            self.encoder = resnet12_nc()
            # self.encoder = nn.Sequential(resnet20())
            # self.encoder = resnet18(False, args)  # pretrained=False
            # self.num_features = 1024
            # self.num_features = 64
            self.num_features = 640
            self.proj_hidden_dim = 2048
            self.proj_output_dim = 128
            self.encoder_outdim = 64
        if self.args.dataset in ['mini_imagenet', 'manyshotmini', 'imagenet100', 'imagenet1000',
                                 'mini_imagenet_withpath']:
            # self.encoder = resnet18(False, args)  # pretrained=False
            self.encoder = resnet12_nc()
            # self.encoder = MaskNet18(args.num_classes)
            # self.num_features = 512
            self.num_features = 640
            self.proj_hidden_dim = 2048
            self.proj_output_dim = 128
            self.encoder_outdim = 512

        if self.args.dataset in ['cub200', 'manyshotcub']:
            self.encoder = resnet18(True,
                                    args,
                                    mode="parallel_adapters")  # pretrained=True follow TOPIC, models for cub is imagenet pre-trained. https://github.com/xyutao/fscil/issues/11#issuecomment-687548790
            # self.encoder = MaskNet18(args.num_classes, pretrain=True)
            self.num_features = 512
            self.proj_hidden_dim = 2048
            self.proj_output_dim = 256
            self.encoder_outdim = 512

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.projection_features = 256
        # 投影层，将特征映射到新的空间
        # self.projection = nn.Linear(self.num_features, self.projection_features, bias=False)  # 示例，将特征映射到 256 维
        self.pre_allocate = self.args.num_classes
        # self.projector = projection_MLP(self.num_features, self.proj_hidden_dim, args.num_proj_layers)

        self.fc = nn.Linear(self.num_features, self.args.num_classes, bias=False)
        #
        # out_features = int(self.args.num_classes + (self.args.base_class - 1) * self.args.base_class / 2)
        # self.fc = nn.Linear(self.num_features, out_features, bias=False)
        # Select the projection head ('g' from the main paper)

        # Final classifier. This hosts the pseudo targets, all and classification happens here
        # self.fc = PseudoTargetClassifier(self.args, self.proj_output_dim)
        # self.fc_base = nn.Linear(self.proj_hidden_dim, out_features, bias=False)

    def fix_backbone(self):
        """Freeze encoder parameters, sparing adapter and ``cls`` parameters.

        Any encoder parameter whose name contains ``'adapter'``, ``'cls'`` or
        ``'running'`` keeps its current ``requires_grad`` flag. Every other encoder
        parameter is frozen. (BatchNorm running statistics are buffers, not
        parameters, so ``'running'`` only matters if a parameter name contains it.)
        """
        spared_tags = ('adapter', 'cls', 'running')
        for param_name, tensor in self.encoder.named_parameters():
            if any(tag in param_name for tag in spared_tags):
                continue
            tensor.requires_grad = False

    def fix_backbone_adapter(self):
        """Freeze every encoder parameter that does not belong to an adapter.

        Parameters whose name contains ``'adapter'`` keep their current
        ``requires_grad`` flag.
        """
        non_adapter = (tensor for param_name, tensor in self.encoder.named_parameters()
                       if 'adapter' not in param_name)
        for tensor in non_adapter:
            tensor.requires_grad = False

    def train_backbone(self):
        """Make the backbone trainable while freezing adapter/alpha/beta parameters.

        Each encoder parameter is set trainable unless its name contains
        ``'adapter'``, ``'alpha'`` or ``'beta'``; those are frozen. The ``fc``
        classifier and, when present, the ``projector`` are made trainable too.
        """
        frozen_tags = ('adapter', 'alpha', 'beta')
        for param_name, tensor in self.encoder.named_parameters():
            tensor.requires_grad = not any(tag in param_name for tag in frozen_tags)

        for head_name in ('fc', 'projector'):
            if not hasattr(self, head_name):
                continue
            for tensor in getattr(self, head_name).parameters():
                tensor.requires_grad = True

    def train_adapter(self):
        """Train only the encoder adapters (and the projector, if any); freeze the rest.

        Steps:
          1. Freeze every parameter of the model.
          2. Unfreeze encoder parameters whose name contains ``'adapter'``.
          3. Keep ``fc`` frozen; unfreeze ``projector`` if the model has one.
          4. For every BatchNorm1d/BatchNorm2d inside the encoder:
             - non-adapter BN layers are put in eval mode (running statistics are
               used, not updated) and their affine weight/bias are frozen;
             - adapter BN layers get trainable affine weight/bias.

        NOTE(review): ``Module.train()`` switches all submodules back to training
        mode, so presumably this must be called after ``model.train()``. Confirm
        against the caller.
        """
        for tensor in self.parameters():
            tensor.requires_grad = False

        adapter_params = (tensor for param_name, tensor in self.encoder.named_parameters()
                          if 'adapter' in param_name)
        for tensor in adapter_params:
            tensor.requires_grad = True

        for head_name, trainable in (('fc', False), ('projector', True)):
            if hasattr(self, head_name):
                for tensor in getattr(self, head_name).parameters():
                    tensor.requires_grad = trainable

        norm_types = (nn.BatchNorm2d, nn.BatchNorm1d)
        for module_name, module in self.encoder.named_modules():
            if not isinstance(module, norm_types):
                continue
            in_adapter = 'adapter' in module_name
            if not in_adapter:
                module.eval()
            for affine in (module.weight, module.bias):
                if affine is not None:
                    affine.requires_grad = in_adapter

    def train_backbone_all(self):
        """Unfreeze every encoder parameter (``requires_grad = True``)."""
        self.encoder.requires_grad_(True)

    def fix_backbone_all(self):
        """Freeze every encoder parameter (``requires_grad = False``)."""
        self.encoder.requires_grad_(False)

    def forward_metric(self, x):
        """Return temperature-scaled classifier logits for ``x``.

        The similarity is chosen by substring match on ``self.mode`` (``'cos'`` is
        checked first):
          * ``'cos'``: cosine similarity between L2-normalised features and
            L2-normalised ``fc`` weights;
          * ``'dot'``: the plain ``fc`` output.
        Both are multiplied by ``self.args.temperature``. If the mode contains
        neither substring, the pooled features are returned unscaled.
        """
        feats = self.encode(x)
        if 'cos' in self.mode:
            unit_feats = F.normalize(feats, p=2, dim=-1)
            unit_weights = F.normalize(self.fc.weight, p=2, dim=-1)
            logits = F.linear(unit_feats, unit_weights)
        elif 'dot' in self.mode:
            logits = self.fc(feats)
        else:
            return feats
        return self.args.temperature * logits

    def encode(self, x, detach_f=False, return_encodings=False):
        """Run the encoder and global-average-pool its output into feature vectors.

        Args:
            x: input batch for ``self.encoder``. The encoder output is pooled with
                ``F.adaptive_avg_pool2d(..., 1)`` and the two trailing singleton
                dims are squeezed, so it is assumed to be an (N, C, H, W) feature
                map. TODO: confirm for every encoder.
            detach_f: if True, the returned ``encodings`` are detached from the
                autograd graph.
            return_encodings: if True, return ``(features, encodings)`` instead of
                only ``features``.

        Returns:
            ``features`` of shape (N, C), or the tuple ``(features, encodings)``.
            No projector is applied, so ``encodings`` holds the same values as
            ``features`` (detached when ``detach_f`` is True).
        """
        x = self.encoder(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.squeeze(-1).squeeze(-1)

        # Always bind `encodings`. Previously return_encodings=True with
        # detach_f=False raised UnboundLocalError.
        encodings = x.detach() if detach_f else x

        if return_encodings:
            return x, encodings
        return x

    def forward(self, input):
        """Dispatch on ``self.mode``.

        ``'encoder'`` returns pooled backbone features (``encode``). Any other mode
        returns classifier logits from ``forward_metric``. The old trailing
        ``else: raise`` was unreachable because ``!=``/``==`` already cover every
        value, so it has been removed.
        """
        if self.mode == 'encoder':
            return self.encode(input)
        return self.forward_metric(input)

    def update_fc(self, dataloader, class_list, session):
        """(Re)initialise classifier weights for ``class_list`` from ``dataloader``.

        Every batch is moved to the classifier's device, encoded without gradient
        tracking, and the features/labels of all batches are concatenated.
        Previously only the last batch survived the loop.

        With ``args.not_data_init`` a new Kaiming-uniform weight matrix is created
        and returned, but it is NOT written into ``self.fc`` (same as before).
        Otherwise ``update_fc_avg`` overwrites each listed class's ``fc`` row with
        the mean feature of its samples.

        Args:
            dataloader: iterable of ``(data, label)`` batches.
            class_list: class indices to initialise.
            session: incremental session index (not used by the computation).

        Returns:
            The new weight tensor, shape (len(class_list), num_features).

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        import math

        # Follow the model's device instead of a hard-coded .cuda().
        device = self.fc.weight.device
        feature_chunks, label_chunks = [], []
        for batch in dataloader:
            data, label = [item.to(device) for item in batch]
            with torch.no_grad():  # features are only used as fixed prototypes
                feature_chunks.append(self.encode(data))
            label_chunks.append(label)
        if not feature_chunks:
            raise ValueError("update_fc received an empty dataloader")
        data = torch.cat(feature_chunks, dim=0)
        label = torch.cat(label_chunks, dim=0)

        if self.args.not_data_init:
            new_fc = nn.Parameter(
                torch.rand(len(class_list), self.num_features, device=device),
                requires_grad=True)
            nn.init.kaiming_uniform_(new_fc, a=math.sqrt(5))
        else:
            new_fc = self.update_fc_avg(data, label, class_list)
        return new_fc

    def update_fc_avg(self, data, label, class_list):
        """Set each listed class's ``fc`` row to the mean feature of its samples.

        Args:
            data: feature matrix with one row per sample.
            label: class index for each row of ``data``.
            class_list: classes whose prototypes are computed and written.

        Returns:
            Tensor stacking the prototypes in ``class_list`` order.

        Raises:
            ValueError: if a class in ``class_list`` has no sample in ``label``.
                Its mean would be NaN and would silently poison ``self.fc``.
                Validation happens before any row is written.
        """
        prototypes = []
        for class_index in class_list:
            members = data[label == class_index]
            if members.shape[0] == 0:
                raise ValueError(f"update_fc_avg: no samples for class {class_index}")
            prototypes.append(members.mean(0))

        for class_index, proto in zip(class_list, prototypes):
            self.fc.weight.data[class_index] = proto
        return torch.stack(prototypes, dim=0)

    def get_logits(self, x, fc):
        """Compute logits of features ``x`` against weight matrix ``fc``.

        The similarity is selected by substring match on ``self.args.new_mode``
        (``'dot'`` is checked first):
          * ``'dot'``: ``x @ fc.T`` (not temperature-scaled);
          * ``'cos'``: ``args.temperature`` times the cosine similarity.

        Raises:
            ValueError: if ``new_mode`` contains neither substring. Previously
                ``None`` was returned silently.
        """
        new_mode = self.args.new_mode
        if 'dot' in new_mode:
            # NOTE(review): unlike forward_metric, the dot branch is not scaled by
            # temperature; kept as-is, confirm this is intended.
            return F.linear(x, fc)
        if 'cos' in new_mode:
            return self.args.temperature * F.linear(F.normalize(x, p=2, dim=-1), F.normalize(fc, p=2, dim=-1))
        raise ValueError(f"Unknown new_mode: {new_mode!r}")

    def soft_calibration(self, args, session, k=5):
        # # # base_protos = self.fc.weight.data[:args.base_class].detach().cpu().data
        # # base_protos = prototypes[:args.base_class + (session - 1) * args.way]
        # # self.fc.weight.data[:args.base_class + (session - 1) * args.way] = base_protos
        # # base_protos = F.normalize(base_protos, p=2, dim=-1)
        # #
        # # # cur_protos = self.fc.weight.data[args.base_class + (
        # # # session - 1) * args.way: args.base_class + session * args.way].detach().cpu().data
        # # cur_protos = prototypes[args.base_class + (
        # #         session - 1) * args.way: args.base_class + session * args.way]
        # # self.fc.weight.data[
        # # args.base_class + (session - 1) * args.way: args.base_class + session * args.way] = cur_protos
        # # cur_protos = F.normalize(cur_protos, p=2, dim=-1)
        # #
        # # weights = torch.mm(cur_protos, base_protos.T) * args.softmax_t
        # # norm_weights = torch.softmax(weights, dim=1)
        # # delta_protos = torch.matmul(norm_weights, base_protos)
        # # #
        # # delta_protos = F.normalize(delta_protos, p=2, dim=-1)
        # #
        # # updated_protos = (1 - args.shift_weight) * cur_protos + args.shift_weight * delta_protos
        # # # updated_protos_base = (1 - args.shift_weight) * base_protos + args.shift_weight * delta_protos
        # 计算余弦相似度
        # base_protos = prototypes[:args.base_class]
        # cur_protos = prototypes[args.base_class + (session - 1) * args.way: args.base_class + session * args.way]
        # cos_sim = torch.mm(F.normalize(cur_protos, p=2, dim=-1),
        #                    F.normalize(base_protos, p=2, dim=-1).T) * args.softmax_t
        #
        # # 对于每个cur_proto找到最近的五个base_protos的索引
        # _, topk_indices = torch.topk(cos_sim, k=k, dim=1, largest=True)
        #
        # # 使用索引获取相应的base_protos并计算平均值
        # topk_base_protos = torch.stack([base_protos[idx] for idx in topk_indices], dim=0)
        # mean_delta_protos = torch.mean(topk_base_protos, dim=1)
        #
        # # 正规化mean_delta_protos
        # delta_protos = F.normalize(mean_delta_protos, p=2, dim=-1)
        # cur_protos = F.normalize(cur_protos, p=2, dim=-1)
        #
        # # 使用新的delta_protos进行更新
        # updated_protos = (1 - args.shift_weight) * cur_protos + args.shift_weight * delta_protos
        #
        # # 以下是对base_protos和cur_protos进行权重更新的操作，保留原始操作的结构
        # # base_protos = F.normalize(base_protos, p=2, dim=-1)
        # # self.fc.weight.data[:args.base_class] = base_protos
        #
        # # cur_protos = updated_protos
        # self.fc.weight.data[
        # args.base_class + (session - 1) * args.way: args.base_class + session * args.way] = updated_protos
        # # cur_protos = F.normalize(cur_protos, p=2, dim=-1)
        #
        # # # 计算新的权重矩阵并应用softmax
        # weights = torch.mm(cur_protos, base_protos.T) * args.softmax_t
        # norm_weights = torch.softmax(weights, dim=1)
        base_protos = self.fc.weight.data[:args.base_class + (
                session - 1) * args.way].detach().cpu().data
        base_protos = F.normalize(base_protos, p=2, dim=-1)

        cur_protos = self.fc.weight.data[args.base_class + (
                session - 1) * args.way: args.base_class + session * args.way].detach().cpu().data
        cur_protos = F.normalize(cur_protos, p=2, dim=-1)

        weights = torch.mm(cur_protos, base_protos.T)* args.softmax_t
        norm_weights = torch.softmax(weights, dim=1)
        delta_protos = torch.matmul(norm_weights, base_protos)

        delta_protos = F.normalize(delta_protos, p=2, dim=-1)

        updated_protos = (1 - args.shift_weight) * cur_protos + args.shift_weight * delta_protos

        self.fc.weight.data[
        args.base_class + (session - 1) * args.way: args.base_class + session * args.way] = updated_protos
        # base_protos = self.fc.weight.data[:args.base_class].detach().cpu().data
        # base_protos = F.normalize(base_protos, p=2, dim=-1)
        #
        # cur_protos = self.fc.weight.data[args.base_class + (
        #         session - 1) * args.way: args.base_class + session * args.way].detach().cpu().data
        # cur_protos = F.normalize(cur_protos, p=2, dim=-1)
        #
        # weights = torch.mm(cur_protos, base_protos.T) * args.softmax_t
        # norm_weights = torch.softmax(weights, dim=1)
        # delta_protos = torch.matmul(norm_weights, base_protos)
        #
        # delta_protos = F.normalize(delta_protos, p=2, dim=-1)
        #
        # updated_protos = (1 - args.shift_weight) * cur_protos + args.shift_weight * delta_protos
        #
        # self.fc.weight.data[
        # args.base_class + (session - 1) * args.way: args.base_class + session * args.way] = updated_protos
