import copy
import logging
import torch
from torch import nn
import math
import timm
import numpy as np
import torch.nn.functional as F


class BaseNet(nn.Module):
    """Wrap a backbone (``convnet``) and an optional classification head (``fc``).

    ``fc`` starts as ``None``; subclasses are expected to install a head via
    ``update_fc``/``generate_fc``, which are no-ops here.
    """

    def __init__(self, args, backbone, pretrained):
        super().__init__()
        self.convnet = backbone
        self.fc = None

    @property
    def feature_dim(self):
        """Width of the backbone features, read from ``convnet.ln_final``."""
        final_norm_shape = self.convnet.ln_final.normalized_shape
        return final_norm_shape[0]

    def extract_vector(self, x):
        """Return only the ``"features"`` entry of the backbone output."""
        backbone_out = self.convnet(x)
        return backbone_out["features"]

    def forward(self, x):
        """Run backbone then head; return the head's dict merged with the backbone's.

        The merged dict looks like::

            {'fmaps': [x_1, ..., x_n], 'features': features, 'logits': logits}

        Keys produced by the backbone overwrite same-named keys from the head.
        """
        backbone_out = self.convnet(x)
        merged = self.fc(backbone_out["features"])
        merged.update(backbone_out)
        return merged

    def update_fc(self, nb_classes):
        """Hook for subclasses to grow the head; does nothing here."""

    def generate_fc(self, in_dim, out_dim):
        """Hook for subclasses to build a head; returns ``None`` here."""

    def copy(self):
        """Return an independent deep copy of this network."""
        return copy.deepcopy(self)

    def freeze(self):
        """Disable gradients for every parameter, switch to eval mode, return self."""
        self.requires_grad_(False)
        return self.eval()
    

def reduce_proxies(out, nb_proxy):
    """Collapse ``nb_proxy`` proxy scores per class into one score per class.

    ``out`` is viewed as ``(batch, nb_classes, nb_proxy)``. Each class score is
    the softmax(over proxies)-weighted sum of its proxy scores. With a single
    proxy, ``out`` is returned unchanged (same object).

    Raises:
        AssertionError: if ``out.shape[1]`` is not a multiple of ``nb_proxy``.
    """
    if nb_proxy == 1:
        return out

    batch_size, total_columns = out.shape[0], out.shape[1]
    classes_as_float = total_columns / nb_proxy
    assert classes_as_float.is_integer(), 'Shape error'

    grouped = out.view(batch_size, int(classes_as_float), nb_proxy)
    proxy_weights = F.softmax(grouped, dim=-1)
    return torch.sum(proxy_weights * grouped, dim=-1)

class CosineLinear(nn.Module):
    """Cosine-similarity classifier head with an optional learnable scale.

    Logits are ``sigma * cos(input, weight_row)``. With ``nb_proxy > 1`` the
    weight holds ``out_features * nb_proxy`` rows; when ``to_reduce`` is set,
    the proxy scores are pooled per class by ``reduce_proxies``.
    """

    def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
        super(CosineLinear, self).__init__()
        total_rows = out_features * nb_proxy
        self.in_features = in_features
        self.out_features = total_rows
        self.nb_proxy = nb_proxy
        self.to_reduce = to_reduce
        # Allocated uninitialised; reset_parameters() fills the values.
        self.weight = nn.Parameter(torch.Tensor(total_rows, in_features))
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in ``[-1/sqrt(in), 1/sqrt(in)]``; scale (if any) set to 1."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        """Return ``{'logits': scores}`` with row-wise cosine similarities."""
        unit_input = F.normalize(input, p=2, dim=1)
        unit_weight = F.normalize(self.weight, p=2, dim=1)
        scores = F.linear(unit_input, unit_weight)

        if self.to_reduce:
            # Pool the nb_proxy scores of each class into one.
            scores = reduce_proxies(scores, self.nb_proxy)

        if self.sigma is not None:
            scores = self.sigma * scores

        return {'logits': scores}



class SimpleClipNet(BaseNet):
    """Network that delegates encoding to a CLIP-like backbone.

    Judging by the calls below, the backbone must provide ``encode_image``,
    ``encode_text`` and a ``forward(img, text)`` that yields three values.
    ``feature_dim`` is inherited from ``BaseNet``.
    """

    def __init__(self, args, backbone, pretrained):
        super().__init__(args, backbone, pretrained)
        self.convnet = backbone
        self.class_name = 'SimpleClipNet'
        self.args = args

    def update_fc(self, nb_classes, nextperiod_initialization=None):
        """Replace ``fc`` with a ``nb_classes``-way cosine head on the GPU.

        Rows and the scale of the previous head (if any) are carried over.
        New rows come from ``nextperiod_initialization`` when given, otherwise
        they are zeros.
        """
        new_fc = self.generate_fc(self.feature_dim, nb_classes).cuda()
        previous = self.fc
        if previous is not None:
            kept_rows = copy.deepcopy(previous.weight.data)
            new_fc.sigma.data = previous.sigma.data
            if nextperiod_initialization is None:
                added_rows = torch.zeros(nb_classes - previous.out_features, self.feature_dim).cuda()
            else:
                added_rows = nextperiod_initialization
            new_fc.weight = nn.Parameter(torch.cat([kept_rows, added_rows]))
        del self.fc
        self.fc = new_fc

    def generate_fc(self, in_dim, out_dim):
        """Build a fresh ``CosineLinear`` head (default proxy/sigma settings)."""
        return CosineLinear(in_dim, out_dim)

    def extract_vector(self, x):
        """Return raw backbone image features.

        This goes to the backbone directly, so subclass overrides of
        ``encode_image`` do not affect it.
        """
        return self.convnet.encode_image(x)

    def encode_image(self, x):
        """Encode images with the backbone."""
        return self.convnet.encode_image(x)

    def encode_text(self, x):
        """Encode text with the backbone."""
        return self.convnet.encode_text(x)

    def forward(self, img, text):
        """Return ``(image_features, text_features, logit_scale)`` from the backbone."""
        image_features, text_features, logit_scale = self.convnet(img, text)
        return image_features, text_features, logit_scale

    def re_initiate(self):
        """Log a re-initialisation request; the backbone itself is left untouched."""
        print('re-initiate model')


def get_attribute(dic, name, default):
    """Return ``dic[name]``, or ``default`` (with a notice printed) when absent.

    Membership is tested with ``in`` before indexing, so a ``defaultdict`` is
    not populated by a missing lookup.
    """
    if name not in dic:
        print(name, 'not in args, set to', default, ' as default')
        return default
    return dic[name]
    
class Proj_Pure_MLP(nn.Module):
    """Single linear projection wrapped in ``nn.Sequential``.

    ``middle_dim`` is accepted for signature compatibility but unused: there
    is no hidden layer. The ``MLP`` container is kept so state_dict keys
    remain ``MLP.0.weight`` / ``MLP.0.bias``.
    """

    def __init__(self, in_features, out_features, middle_dim):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        projection = nn.Linear(in_features, out_features)
        self.MLP = nn.Sequential(projection)

    def forward(self, input):
        """Project ``input`` from ``in_features`` to ``out_features``."""
        return self.MLP(input)



class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        # Divisor applied to the raw q·k scores.
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        """Return ``(output, attn, log_attn)`` for batched 3-D ``q``, ``k``, ``v``.

        ``attn`` is the post-dropout softmax of ``q @ k^T / temperature`` over
        the key axis; ``log_attn`` is its pre-dropout log-softmax.
        """
        scaled_scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        log_weights = F.log_softmax(scaled_scores, 2)
        weights = self.dropout(self.softmax(scaled_scores))
        return torch.bmm(weights, v), weights, log_weights


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        qk_std = np.sqrt(2.0 / (d_model + d_k))
        v_std = np.sqrt(2.0 / (d_model + d_v))

        # Bias-free projections producing all heads at once.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        for layer, std in ((self.w_qs, qk_std), (self.w_ks, qk_std), (self.w_vs, v_std)):
            nn.init.normal_(layer.weight, mean=0, std=std)

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """Attend ``q`` over ``k``/``v``; return ``LayerNorm(q + proj(attn_out))``.

        Inputs are indexed as ``(batch, length, d_model)`` (from ``.size()``
        unpacking); the output has the shape of ``q``.
        """
        n_head = self.n_head
        batch, len_q, _ = q.size()
        residual = q

        def split_heads(x, projection, head_dim):
            # (b, len, n_head*head_dim) -> (n_head*b, len, head_dim)
            b, length, _ = x.size()
            heads = projection(x).view(b, length, n_head, head_dim)
            return heads.permute(2, 0, 1, 3).contiguous().view(-1, length, head_dim)

        q_heads = split_heads(q, self.w_qs, self.d_k)
        k_heads = split_heads(k, self.w_ks, self.d_k)
        v_heads = split_heads(v, self.w_vs, self.d_v)

        attended, _, _ = self.attention(q_heads, k_heads, v_heads)

        # (n_head*b, len_q, d_v) -> (b, len_q, n_head*d_v)
        merged = attended.view(n_head, batch, len_q, self.d_v)
        merged = merged.permute(1, 2, 0, 3).contiguous().view(batch, len_q, -1)

        projected = self.dropout(self.fc(merged))
        return self.layer_norm(projected + residual)


class Proof_Net(SimpleClipNet):
    """CLIP network with per-task projections and joint cross-modal attention.

    Per incremental task, ``extend_task`` appends one image projection and one
    text projection, and ``update_context_prompt`` appends a block of learnable
    context prompts. Encoded features are the *sum* over all projections.
    ``forward`` then runs single-head attention (``sel_attn``) jointly over the
    image token, the text tokens, the class-prototype tokens and the prompts.
    """

    def __init__(self, args, backbone, pretrained):
        super().__init__(args, backbone, pretrained)
        self.projs_img = nn.ModuleList()
        self.projs_text = nn.ModuleList()
        self.args = args
        self.projtype = get_attribute(self.args, 'projection_type', 'pure_mlp')
        self.context_prompt_length_per_task = get_attribute(self.args, 'context_prompt_length_per_task', 3)
        # One head; model, key and value widths all equal the backbone width.
        self.sel_attn = MultiHeadAttention(1, self.feature_dim, self.feature_dim, self.feature_dim, dropout=0.1)
        # Plain tensor (not a parameter/buffer) of shape [nb_classes, feature_dim],
        # allocated by update_prototype. Rows are presumably written by the
        # training code (not visible here).
        self.img_prototypes = None

        self.context_prompts = nn.ParameterList()

    def _model_device(self):
        """Device on which the attention block's parameters currently live."""
        return next(self.sel_attn.parameters()).device

    @staticmethod
    def _project_and_sum(features, projs, normalize):
        """Apply every projection in ``projs`` to ``features`` and sum the results.

        Returns a ``[n, dim]`` tensor, L2-normalised along the last axis when
        ``normalize`` is true.
        """
        features = features.float()
        projected = torch.stack([proj(features) for proj in projs], dim=1)  # [n, num_proj, dim]
        summed = torch.sum(projected, dim=1)  # [n, dim]
        return F.normalize(summed, dim=-1) if normalize else summed

    def update_prototype(self, nb_classes):
        """Grow ``img_prototypes`` to ``nb_classes`` rows, zero-filling new rows."""
        if self.img_prototypes is not None:
            nb_output = len(self.img_prototypes)
            # Allocate the new rows next to the existing ones. A hard-coded
            # "cuda" here crashed torch.cat while prototypes were still on CPU.
            new_rows = torch.zeros(nb_classes - nb_output, self.feature_dim,
                                   device=self.img_prototypes.device)
            self.img_prototypes = torch.cat([self.img_prototypes, new_rows])
        else:
            self.img_prototypes = torch.zeros(nb_classes, self.feature_dim)
        print('update prototype, now we have {} prototypes'.format(self.img_prototypes.shape[0]))

    def update_context_prompt(self):
        """Freeze existing prompt blocks and append a new randomly initialised one.

        The new parameter is created on the CPU, like the projections added by
        ``extend_task``; moving it to the model's device is left to the caller.
        """
        for prompt in self.context_prompts:
            prompt.requires_grad = False
        self.context_prompts.append(nn.Parameter(torch.randn(self.context_prompt_length_per_task, self.feature_dim)))
        print('update context prompt, now we have {} context prompts'.format(len(self.context_prompts) * self.context_prompt_length_per_task))

    def get_context_prompts(self):
        """Concatenate all prompt blocks into one ``[num_prompts, dim]`` tensor."""
        return torch.cat(list(self.context_prompts), dim=0)

    def encode_image(self, x, normalize: bool = False):
        """Backbone image features passed through (and summed over) all image projections."""
        basic_img_features = self.convnet.encode_image(x)
        return self._project_and_sum(basic_img_features, self.projs_img, normalize)

    def encode_text(self, x, normalize: bool = False):
        """Project ``x`` through (and sum over) all text projections.

        ``x`` is used as-is: it is expected to already be backbone text
        features, not token ids.
        """
        return self._project_and_sum(x, self.projs_text, normalize)

    def encode_prototpyes(self, normalize: bool = False):
        """Project the stored image prototypes like image features (name kept for callers).

        As a side effect, ``img_prototypes`` is moved to the model's device
        (previously a hard-coded ``"cuda"``).
        """
        self.img_prototypes = self.img_prototypes.to(self._model_device())
        return self._project_and_sum(self.img_prototypes, self.projs_img, normalize)

    def extend_task(self):
        """Add one image and one text projection for a new task."""
        self.projs_img.append(self.extend_item())
        self.projs_text.append(self.extend_item())

    def extend_item(self):
        """Build a projection of the configured ``projtype``."""
        if self.projtype == 'pure_mlp':
            return Proj_Pure_MLP(self.feature_dim, self.feature_dim, self.feature_dim)
        else:
            raise NotImplementedError

    def _fuse_features(self, image_features, text_features, prototype_features):
        """Run ``sel_attn`` over [image | texts | prototypes | context prompts].

        Text, prototype and prompt tokens are shared by every image in the
        batch. Returns the attended image token ``[bs, dim]`` and the attended
        text ``[num_texts, dim]`` and prototype ``[num_protos, dim]`` tokens,
        each averaged over the batch. Attended prompt tokens are discarded.
        """
        dim = self.feature_dim
        context_prompts = self.get_context_prompts()  # [num_prompts, dim]
        len_texts = text_features.shape[0]
        len_protos = prototype_features.shape[0]
        len_prompts = context_prompts.shape[0]

        image_tokens = image_features.view(image_features.shape[0], -1, dim)  # [bs, 1, dim]
        bs = image_tokens.shape[0]
        # expand() broadcasts without copying.
        text_tokens = text_features.view(len_texts, dim).expand(bs, len_texts, dim)
        proto_tokens = prototype_features.view(len_protos, dim).expand(bs, len_protos, dim)
        prompt_tokens = context_prompts.view(len_prompts, dim).expand(bs, len_prompts, dim)

        tokens = torch.cat([image_tokens, text_tokens, proto_tokens, prompt_tokens], dim=1)
        tokens = self.sel_attn(tokens, tokens, tokens)

        text_end = 1 + len_texts
        fused_image = tokens[:, 0, :]
        fused_text = torch.mean(tokens[:, 1:text_end, :], dim=0)
        fused_protos = torch.mean(tokens[:, text_end:text_end + len_protos, :], dim=0)
        return fused_image, fused_text, fused_protos

    def forward(self, image, text):
        """Encode, fuse with attention, and return features plus the logit scale.

        Returns ``(image_features [bs, dim], text_features [num_texts, dim],
        exp(logit_scale), prototype_features [num_protos, dim])``.
        """
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        prototype_features = self.encode_prototpyes(normalize=True)
        image_features, text_features, prototype_features = self._fuse_features(
            image_features, text_features, prototype_features)
        return image_features, text_features, self.convnet.logit_scale.exp(), prototype_features

    def forward_transformer(self, image_features, text_features, transformer=False):
        """Like ``forward`` but on pre-encoded features; fusion only if ``transformer``."""
        prototype_features = self.encode_prototpyes(normalize=True)
        if transformer:
            image_features, text_features, prototype_features = self._fuse_features(
                image_features, text_features, prototype_features)
        return image_features, text_features, self.convnet.logit_scale.exp(), prototype_features

    def freeze_projection_weight_new(self):
        """With more than one task: freeze old image projections, train the newest.

        ``sel_attn`` is always made trainable.
        """
        if len(self.projs_img) > 1:
            for img_proj, text_proj in zip(self.projs_img, self.projs_text):
                for param in img_proj.parameters():
                    param.requires_grad = False
                # NOTE(review): text projections are set trainable, not frozen,
                # unlike the image ones. Behavior kept as-is; confirm intended.
                for param in text_proj.parameters():
                    param.requires_grad = True
            for param in self.projs_img[-1].parameters():
                param.requires_grad = True
        for param in self.sel_attn.parameters():
            param.requires_grad = True


class ClipLoss(nn.Module):
    """Symmetric image-text contrastive loss (CLIP style).

    Row ``i`` of the image features is the positive for row ``i`` of the text
    features; all other rows in the (optionally cross-rank gathered) batch
    are negatives. The loss is the mean of image->text and text->image
    cross-entropies.

    Args:
        local_loss: when ``world_size > 1``, score only this rank's rows
            against all gathered rows (labels are offset by ``rank``).
        gather_with_grad: gather via ``torch.distributed.nn.all_gather`` so
            gradients flow across ranks.
        cache_labels: reuse the label tensor per device while the number of
            logits is unchanged.
        rank, world_size: distributed process coordinates.
        use_horovod: unsupported; gathering raises ``NotImplementedError``.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return ``arange(num_logits)`` (rank-offset for local loss), cached if enabled."""
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def _gather_features(self, image_features, text_features):
        """Concatenate features from all ranks along dim 0 using torch.distributed.

        Without ``gather_with_grad`` and without ``local_loss``, this rank's
        slot is replaced by the local tensors so local gradients are kept.
        Requires an initialised default process group.
        """
        if self.use_horovod:
            raise NotImplementedError('Horovod feature gathering is not supported by ClipLoss')
        import torch.distributed as dist

        if self.gather_with_grad:
            from torch.distributed import nn as dist_nn
            all_image_features = torch.cat(dist_nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(dist_nn.all_gather(text_features), dim=0)
            return all_image_features, all_text_features

        gathered_image = [torch.zeros_like(image_features) for _ in range(self.world_size)]
        gathered_text = [torch.zeros_like(text_features) for _ in range(self.world_size)]
        dist.all_gather(gathered_image, image_features)
        dist.all_gather(gathered_text, text_features)
        if not self.local_loss:
            # all_gather outputs carry no grad; restore the local tensors.
            gathered_image[self.rank] = image_features
            gathered_text[self.rank] = text_features
        return torch.cat(gathered_image, dim=0), torch.cat(gathered_text, dim=0)

    def get_logits(self, image_features, text_features, logit_scale):
        """Return ``(logits_per_image, logits_per_text)`` scaled by ``logit_scale``."""
        if self.world_size > 1:
            all_image_features, all_text_features = self._gather_features(image_features, text_features)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Return the symmetric loss, or ``{"contrastive_loss": loss}`` if ``output_dict``."""
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss