import torch
import torch.nn as nn
from torch.nn import init
from transformers import CLIPModel
import math
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
import re

class Noise_Tracker(nn.Module):
    """Encode each slice along dim 2 of the input, then score the sequence.

    ``feature_net`` is applied independently to every slice obtained by
    unbinding the input along dim 2.  The per-slice outputs are stacked,
    reordered so the slice axis comes second, fed to ``time_net`` and
    squashed with a sigmoid.
    """

    def __init__(self, feature_net, time_net):
        super().__init__()
        self.feature_net = feature_net
        self.time_net = time_net
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(time_net(stacked per-slice features))``."""
        # One feature_net call per slice, in slice order.
        per_slice = [self.feature_net(piece) for piece in torch.unbind(x, dim=2)]
        # (slices, d0, d1) -> (d0, slices, d1); d0 is presumably the batch axis.
        sequence = torch.stack(per_slice, 0).permute(1, 0, 2)
        return self.sigmoid(self.time_net(sequence))

class Dino_Pure_Tracker(nn.Module):
    """Frozen backbone + sequence head over pooled per-slice embeddings.

    Every parameter of ``model`` has ``requires_grad`` switched off at
    construction.  ``cos`` is created but not used by ``forward``.
    """

    def __init__(self, model, time_net):
        super().__init__()
        self.model = model
        self.time_net = time_net
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)
        self.cos = nn.CosineSimilarity(dim=1)

    def forward(self, x):
        """Return ``sigmoid(time_net(...))`` over the stacked ``pooler_output``s."""
        # The backbone output must expose ``pooler_output``.
        pooled = [self.model(piece).pooler_output for piece in torch.unbind(x, dim=2)]
        # (slices, d0, d1) -> (d0, slices, d1); d0 is presumably the batch axis.
        sequence = torch.stack(pooled, 0).permute(1, 0, 2)
        return self.sigmoid(self.time_net(sequence))

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Dino_Gram_Tracker(nn.Module):
    """Classify a sequence from the Gram matrix of its pooled embeddings.

    The input is unbound along dim 2; each slice goes through the frozen
    ``model`` and its ``pooler_output`` is kept.  The slice-by-slice Gram
    matrix (``E @ E^T`` per sample) is flattened, batch-normalised and mapped
    to ``num_classes`` sigmoid outputs.

    Args:
        model: backbone whose output exposes ``pooler_output``; all of its
            parameters are frozen here.
        num_frames: number of slices along dim 2 of the input.  The Gram
            matrix is ``num_frames x num_frames``, which fixes the width of
            the normalisation and linear layers.  Defaults to 10, i.e. the
            previously hard-coded 100 features.
        num_classes: number of outputs of the final linear layer (default 2).
    """

    def __init__(self, model, num_frames=10, num_classes=2):
        super().__init__()
        self.model = model
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)
        # Flattened Gram matrix has num_frames ** 2 entries per sample.
        gram_features = num_frames * num_frames
        self.norm = nn.BatchNorm1d(gram_features)
        self.fc = nn.Linear(in_features=gram_features, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return sigmoid scores of shape ``(d0, num_classes)``.

        The number of slices along dim 2 of ``x`` must equal ``num_frames``,
        otherwise the batch-norm layer rejects the flattened Gram matrix.
        """
        pooled = [self.model(frame).pooler_output for frame in torch.unbind(x, dim=2)]
        # (slices, d0, d1) -> (d0, slices, d1)
        emb = torch.stack(pooled, 0).permute(1, 0, 2)
        # Pairwise inner products between slice embeddings: (d0, slices, slices).
        gram = torch.bmm(emb, emb.transpose(1, 2))
        flat = gram.flatten(start_dim=1)
        return self.sigmoid(self.fc(self.norm(flat)))

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

class Dino_Tracker(nn.Module):
    """Frozen backbone + sequence head over offsets from the first slice.

    Each slice along dim 2 is embedded with the frozen ``model``
    (``pooler_output``).  The first embedding serves as the reference; the
    sequence fed to ``time_net`` consists of every later embedding minus that
    reference, so it has one entry fewer than the input has slices.
    """

    def __init__(self, model, time_net):
        super().__init__()
        self.model = model
        self.time_net = time_net
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)

    def forward(self, x):
        """Return ``sigmoid(time_net(offsets))``."""
        pooled = [self.model(piece).pooler_output for piece in torch.unbind(x, dim=2)]
        # Slice 0 is only the reference and contributes no entry of its own.
        offsets = [later - pooled[0] for later in pooled[1:]]
        sequence = torch.stack(offsets, 0).permute(1, 0, 2)
        return self.sigmoid(self.time_net(sequence))

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Dino_Single_Tracker(nn.Module):
    """Frozen backbone + sequence head over consecutive-slice differences.

    Each slice along dim 2 is embedded with the frozen ``model``
    (``pooler_output``).  The sequence fed to ``time_net`` is the difference
    between each embedding and the one immediately before it, so it has one
    entry fewer than the input has slices.
    """

    def __init__(self, model, time_net):
        super().__init__()
        self.model = model
        self.time_net = time_net
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)

    def forward(self, x):
        """Return ``sigmoid(time_net(step-to-step differences))``."""
        pooled = [self.model(piece).pooler_output for piece in torch.unbind(x, dim=2)]
        # Pair every embedding with its predecessor.
        steps = [cur - prev for prev, cur in zip(pooled, pooled[1:])]
        sequence = torch.stack(steps, 0).permute(1, 0, 2)
        return self.sigmoid(self.time_net(sequence))

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Dino_Cosine_Tracker(nn.Module):
    """Frozen backbone + sequence head over cosine similarity to slice 0.

    Each slice along dim 2 is run through the frozen ``model`` and its
    ``last_hidden_state`` is kept.  Every later slice is compared with the
    first one using ``nn.CosineSimilarity(dim=1)``; the similarities form the
    sequence given to ``time_net``.
    """

    def __init__(self, model, time_net):
        super().__init__()
        self.model = model
        self.time_net = time_net
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)
        self.cos = nn.CosineSimilarity(dim=1)

    def forward(self, x):
        """Return ``sigmoid(time_net(similarities))``."""
        hidden = [self.model(piece).last_hidden_state for piece in torch.unbind(x, dim=2)]
        # NOTE(review): the similarity reduces over dim 1 of last_hidden_state,
        # which is presumably the token axis rather than the feature axis --
        # confirm this is intended.
        similarities = [self.cos(later, hidden[0]) for later in hidden[1:]]
        sequence = torch.stack(similarities, 0).permute(1, 0, 2)
        return self.sigmoid(self.time_net(sequence))

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Dino_Clip(nn.Module):
    """Linear probe on the ``pooler_output`` of a frozen backbone.

    The head expects 1024-dimensional pooled features.  ``sigmoid`` is
    created but not applied in ``forward``, which returns raw head outputs.
    """

    def __init__(self, model, num_classes=2):
        super().__init__()
        self.model = model
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)
        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``fc(model(x).pooler_output)``."""
        pooled = self.model(x).pooler_output
        return self.fc(pooled)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Dino_Baseline(nn.Module):
    """Linear probe on the ``pooler_output`` of a frozen backbone.

    The head expects 1024-dimensional pooled features.  ``sigmoid`` is
    created but not applied in ``forward``, which returns raw head outputs.
    """

    def __init__(self, model, num_classes=2):
        super().__init__()
        self.model = model
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)
        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``fc(model(x).pooler_output)``."""
        pooled = self.model(x).pooler_output
        return self.fc(pooled)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Clip_Lora(nn.Module):
    """Backbone with LoRA adapters on ``q_proj``/``v_proj`` plus a linear head.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with a LoRA config of rank ``rank`` (alpha 32, dropout 0.05, no bias
    adaptation).  The head maps 1024-dimensional pooled features to
    ``num_classes`` outputs.  Construction also calls
    ``print_trainable_parameters()`` on the wrapped model.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        adapter_cfg = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        prepared = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(prepared, adapter_cfg)
        self.model.print_trainable_parameters()

        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``fc(model(x).pooler_output)``."""
        pooled = self.model(x).pooler_output
        return self.fc(pooled)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Clip_Lora_QKV(nn.Module):
    """Backbone with LoRA adapters on ``q_proj``/``k_proj``/``v_proj`` plus a linear head.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with a LoRA config of rank ``rank`` (alpha 32, dropout 0.05, no bias
    adaptation).  The head maps 1024-dimensional pooled features to
    ``num_classes`` outputs.  Construction also calls
    ``print_trainable_parameters()`` on the wrapped model.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        adapter_cfg = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        prepared = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(prepared, adapter_cfg)
        self.model.print_trainable_parameters()

        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``fc(model(x).pooler_output)``."""
        pooled = self.model(x).pooler_output
        return self.fc(pooled)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Clip_Lora_Mid(nn.Module):
    """LoRA-adapted backbone with an intermediate embedding layer and a head.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with LoRA on ``q_proj``/``v_proj`` (alpha 32, dropout 0.05, no bias
    adaptation).  Construction also calls ``print_trainable_parameters()``
    on the wrapped model.

    Args:
        model: backbone whose output exposes ``pooler_output`` (1024-d).
        num_classes: number of outputs of the classification head.
        rank: LoRA rank.  Defaults to 4, the value that used to be
            hard-coded, matching the ``rank`` argument of the sibling LoRA
            wrappers in this file.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        lora_config = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()
        # fc produces the intermediate embedding, fc1 the class scores.
        self.fc = nn.Linear(in_features=1024, out_features=1024, bias=True)
        self.fc1 = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return ``(scores, emb)`` with ``emb = fc(pooled)`` and ``scores = fc1(emb)``."""
        pooled = self.model(x).pooler_output
        emb = self.fc(pooled)
        return self.fc1(emb), emb

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

class Clip_Lora_Seg(nn.Module):
    """LoRA-adapted backbone returning class scores and per-token features.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with LoRA on ``q_proj``/``v_proj`` at rank ``rank`` (alpha 32, dropout
    0.05, no bias adaptation).  Construction also calls
    ``print_trainable_parameters()`` on the wrapped model.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        adapter_cfg = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        prepared = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(prepared, adapter_cfg)
        self.model.print_trainable_parameters()

        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``(scores, emb)``.

        ``scores`` is the head applied to ``pooler_output``; ``emb`` is
        ``last_hidden_state`` without its first token (presumably the CLS
        token -- confirm against the backbone).
        """
        outputs = self.model(x)
        hidden = outputs.last_hidden_state
        # 3-D states carry a leading batch axis; otherwise tokens are axis 0.
        emb = hidden[:, 1:, :] if hidden.dim() == 3 else hidden[1:, :]
        scores = self.fc(outputs.pooler_output)
        return scores, emb

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Clip_Lora_Seg_QKV(nn.Module):
    """LoRA-adapted backbone returning class scores and per-token features.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with LoRA on ``q_proj``/``k_proj``/``v_proj`` (alpha 32, dropout 0.05,
    no bias adaptation).  Construction also calls
    ``print_trainable_parameters()`` on the wrapped model.

    Args:
        model: backbone whose output exposes ``last_hidden_state`` and a
            1024-d ``pooler_output``.
        num_classes: number of outputs of the classification head.
        rank: LoRA rank.  Defaults to 4, the value that used to be
            hard-coded, matching the ``rank`` argument of the sibling LoRA
            wrappers in this file.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        lora_config = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return ``(scores, emb)``.

        ``scores`` is the head applied to ``pooler_output``; ``emb`` is
        ``last_hidden_state`` without its first token (presumably the CLS
        token -- confirm against the backbone).
        """
        outputs = self.model(x)
        hidden = outputs.last_hidden_state
        # 3-D states carry a leading batch axis; otherwise tokens are axis 0.
        emb = hidden[:, 1:, :] if hidden.dim() == 3 else hidden[1:, :]
        scores = self.fc(outputs.pooler_output)
        return scores, emb

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

class Dino_Lora(nn.Module):
    """Backbone with LoRA adapters on ``query``/``value`` plus a linear head.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    (alpha 32, dropout 0.05, no bias adaptation).  Construction also calls
    ``print_trainable_parameters()`` on the wrapped model.

    Args:
        model: backbone whose output exposes a 1024-d ``pooler_output``.
        num_classes: number of outputs of the classification head.
        rank: LoRA rank.  Defaults to 4, the value that used to be
            hard-coded, matching the ``rank`` argument of the sibling LoRA
            wrappers in this file.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        lora_config = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["query", "value"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return ``fc(model(x).pooler_output)``."""
        pooled = self.model(x).pooler_output
        return self.fc(pooled)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

class Dino_Lora_Mid(nn.Module):
    """LoRA-adapted backbone with an intermediate embedding layer and a head.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with LoRA on ``query``/``value`` (alpha 32, dropout 0.05, no bias
    adaptation).  Construction also calls ``print_trainable_parameters()``
    on the wrapped model.

    Args:
        model: backbone whose output exposes a 1024-d ``pooler_output``.
        num_classes: number of outputs of the classification head.
        rank: LoRA rank.  Defaults to 4, the value that used to be
            hard-coded, matching the ``rank`` argument of the sibling LoRA
            wrappers in this file.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        lora_config = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["query", "value"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        # fc produces the intermediate embedding, fc1 the class scores.
        self.fc = nn.Linear(in_features=1024, out_features=1024, bias=True)
        self.fc1 = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        """Return ``(scores, emb)`` with ``emb = fc(pooled)`` and ``scores = fc1(emb)``."""
        pooled = self.model(x).pooler_output
        emb = self.fc(pooled)
        return self.fc1(emb), emb

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

class Dino_Lora_Seg(nn.Module):
    """LoRA-adapted backbone returning class scores and per-token features.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with LoRA on ``query``/``value`` at rank ``rank`` (alpha 32, dropout
    0.05, no bias adaptation).  Construction also calls
    ``print_trainable_parameters()`` on the wrapped model.
    """

    def __init__(self, model, num_classes=2, rank=4):
        super().__init__()
        self.model = model
        self.freeze(self.model)
        adapter_cfg = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["query", "value"],
            lora_dropout=0.05,
            bias="none",
        )
        prepared = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(prepared, adapter_cfg)
        self.model.print_trainable_parameters()

        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``(scores, emb)``.

        ``scores`` is the head applied to ``pooler_output``; ``emb`` is
        ``last_hidden_state`` without its first token (presumably the CLS
        token -- confirm against the backbone).
        """
        outputs = self.model(x)
        hidden = outputs.last_hidden_state
        # 3-D states carry a leading batch axis; otherwise tokens are axis 0.
        emb = hidden[:, 1:, :] if hidden.dim() == 3 else hidden[1:, :]
        scores = self.fc(outputs.pooler_output)
        return scores, emb

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad = False

class Dino_Lora_Seg_CNN(nn.Module):
    """LoRA-adapted backbone with a per-token scorer and a mask decoder.

    The backbone is frozen, passed through
    ``prepare_model_for_kbit_training`` and wrapped by ``get_peft_model``
    with LoRA on ``query``/``value`` (alpha 32, dropout 0.05, no bias
    adaptation).  Construction also calls ``print_trainable_parameters()``
    on the wrapped model.

    Args:
        model: backbone whose output exposes ``last_hidden_state`` with
            1024 features per token.
        num_classes: number of outputs of the per-token scorer ``kfc``.
        decoder_type: ``"conv-<N>"`` builds a convolutional upsampling
            decoder with ``N // 4`` conv blocks per stage; ``"linear"``
            builds a single ``Linear(1024, 1)`` applied to each token.  Any
            other value yields an empty ``nn.Sequential``.
        rank: LoRA rank.  Defaults to 4, the value that used to be
            hard-coded.
    """

    def __init__(self, model, num_classes=2, decoder_type="conv-4", rank=4):
        super().__init__()
        # ``name`` and ``intermidiate_layer_output`` are only read by
        # _set_decoder to decide on CNN-backbone specific layers.
        self.name = "dino"
        self.model = model
        self.decoder_type = decoder_type
        self.freeze(self.model)
        lora_config = LoraConfig(
            r=rank,
            lora_alpha=32,
            target_modules=["query", "value"],
            lora_dropout=0.05,
            bias="none",
        )
        self.model = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(self.model, lora_config)
        self.intermidiate_layer_output = None
        self.model.print_trainable_parameters()
        self.kfc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)
        self._set_decoder()

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the ``[Conv2d(5x5), BatchNorm2d, ReLU]`` triple for one block."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def _set_decoder(self):
        """Build ``self.fc`` according to ``self.decoder_type``.

        Raises:
            ValueError: if a conv decoder name does not end in digits.
        """
        upscaling_layers = []
        is_cnn_backbone = "RN50" in self.name or "xceptionnet" in self.name
        if "conv" in self.decoder_type:
            filter_sizes = [1024, 512, 256, 128, 64]
            suffix = re.search(r'\d{0,3}$', self.decoder_type).group()
            if not suffix:
                raise ValueError(
                    f"conv decoder_type must end with the number of convs, got {self.decoder_type!r}"
                )
            # N convs are spread over the four stages; one per stage is mandatory.
            extra_convs = int(suffix) // 4 - 1

            stages = zip(filter_sizes, filter_sizes[1:])
            for i, (in_channels, out_channels) in enumerate(stages, start=1):
                upscaling_layers.extend(self._conv_block(in_channels, out_channels))
                for _ in range(extra_convs):
                    upscaling_layers.extend(self._conv_block(out_channels, out_channels))

                # skip some upscaling layers if the input is too large (case for CNNs)
                skip_upscaling = (
                    self.intermidiate_layer_output == "layer2" and i == 1
                    or self.intermidiate_layer_output == "layer1" and i <= 2
                ) and is_cnn_backbone
                if skip_upscaling:
                    continue

                upscaling_layers.append(nn.Upsample(scale_factor=2, mode='bilinear'))

            # CNNs output may not be in (256, 256) - usually a (224, 224) size
            if is_cnn_backbone:
                upscaling_layers.append(nn.Upsample(size=(256, 256), mode='bilinear'))

            upscaling_layers.append(nn.Conv2d(64, 1, kernel_size=5, padding=2))

        elif self.decoder_type == "linear":
            # Xceptionnet
            if self.name == "xceptionnet":
                upscaling_layers.append(nn.Linear(784, 1))
            # CLIP
            else:
                upscaling_layers.append(nn.Linear(1024, 1))

        self.fc = nn.Sequential(*upscaling_layers)

    def _feature_map_transform(self, input):
        """Reshape ``(B, N, C)`` tokens into a ``(B, C, s, s)`` map, ``s = int(sqrt(N))``."""
        output = input.permute(0, 2, 1)
        batch, channels, num_tokens = output.size()
        side = int(num_tokens ** 0.5)
        return output.view(batch, channels, side, side)

    def forward(self, x):
        """Return ``(mask, score, emb)``.

        ``emb`` is ``last_hidden_state`` without its first token (presumably
        the CLS token -- confirm against the backbone), ``score`` is ``kfc``
        applied to every remaining token, and ``mask`` is the decoder output
        flattened from dim 1 onwards.
        """
        hidden = self.model(x).last_hidden_state
        # NOTE: the 2-D (un-batched) layout only works with the linear
        # decoder; _feature_map_transform permutes three axes.
        tokens = hidden[:, 1:, :] if hidden.dim() == 3 else hidden[1:, :]
        emb = tokens
        score = self.kfc(tokens)
        if self.decoder_type == "linear":
            # The linear decoder acts on the feature axis of each token, so it
            # must see the token layout; reshaping to a 2-D map first would
            # feed it a spatial axis and fail on the size mismatch.
            mask = self.fc(tokens)
        else:
            mask = self.fc(self._feature_map_transform(tokens))
        mask = torch.flatten(mask, start_dim=1)
        return mask, score, emb

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

class Dino_Lora_handcrafted(nn.Module):
    """Frozen backbone with hand-written LoRA on the fused qkv projections.

    For every selected transformer block ``model.blocks[i]`` the fused
    ``attn.qkv`` linear layer is replaced by ``_LoRA_qkv``, which adds
    low-rank updates to the query and value parts.  The backbone weights are
    frozen before the adapters are created, so only the adapters and the
    classification head stay trainable.

    Args:
        model: backbone exposing ``blocks`` whose items have ``attn.qkv``.
        num_classes: number of outputs of the classification head.
        r: rank of the low-rank adapters (default 4).
        lora_layer: iterable of block indices to adapt; ``None`` (default)
            adapts every block.
    """

    def __init__(self, model, num_classes=2, r=4, lora_layer=None):
        super().__init__()
        self.model = model
        self.sigmoid = nn.Sigmoid()
        self.freeze(self.model)
        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

        blocks = list(model.blocks)
        self.r = r
        self.lora_layer = list(range(len(blocks))) if lora_layer is None else list(lora_layer)
        # Plain lists on purpose: the adapters are registered as parameters
        # through the _LoRA_qkv modules placed inside the backbone, so a
        # ModuleList here would list them a second time in the state dict.
        self.w_As = []
        self.w_Bs = []

        for t_layer_i, blk in enumerate(blocks):
            # If we only want few lora layer instead of all
            if t_layer_i not in self.lora_layer:
                continue
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, r, bias=False)
            w_b_linear_q = nn.Linear(r, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, r, bias=False)
            w_b_linear_v = nn.Linear(r, self.dim, bias=False)
            self.w_As.extend([w_a_linear_q, w_a_linear_v])
            self.w_Bs.extend([w_b_linear_q, w_b_linear_v])
            blk.attn.qkv = _LoRA_qkv(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
            )
        self.reset_parameters()

    def forward(self, x):
        """Return the head applied to the pooled backbone features.

        Uses ``pooler_output`` when the backbone output has that attribute
        and otherwise treats the output itself as the pooled features
        (presumably what a ``blocks``/``attn.qkv`` style backbone returns --
        confirm against the backbone in use).
        """
        out = self.model(x)
        pooled = getattr(out, "pooler_output", out)
        return self.fc(pooled)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def reset_parameters(self) -> None:
        """Initialise adapters: Kaiming-uniform for A, zeros for B.

        Zero B matrices make the low-rank update vanish at the start, so the
        adapted layers initially reproduce the frozen projection.
        """
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)

class _LoRA_qkv(nn.Module):
    """In Dinov2 it is implemented as
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    """

    def __init__(
            self,
            qkv: nn.Module,
            linear_a_q: nn.Module,
            linear_b_q: nn.Module,
            linear_a_v: nn.Module,
            linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        self.dim = qkv.in_features
        self.w_identity = torch.eye(qkv.in_features)

    def forward(self, x):
        qkv = self.qkv(x)  # B,N,3*org_C
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        
        qkv[:, :, : self.dim] += new_q
        qkv[:, :, -self.dim:] += new_v
        return qkv
