import numpy as np
import scipy.io as sio
import torch
from sklearn import preprocessing
from torch.utils.data import Dataset
from utility.utils import *

from torch import Tensor
import torch.nn as nn
import re
import clip
import pandas as pd
import sys

class cos_sim_loss(nn.MSELoss):
    """Cosine-distance loss: ``1 - cos(input, target)`` computed along ``dim``.

    Subclasses ``nn.MSELoss`` only to reuse its constructor and ``reduction``
    plumbing; ``forward`` does not compute a squared error.

    Args:
        dim: dimension along which cosine similarity is computed.
        size_average, reduce: legacy arguments forwarded to the base class.
            Whatever the base class derives from them is overridden by
            ``reduction`` below.
        reduction: ``'none'`` or ``None`` (per-element loss with a trailing
            singleton dimension), ``'mean'`` or ``'sum'``.

    Raises:
        ValueError: if ``reduction`` is not one of the values listed above.
    """
    __constants__ = ['reduction']

    def __init__(self, dim=1, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
        # after which an unknown reduction would silently fall through to ``sum``.
        if reduction not in ('none', None, 'mean', 'sum'):
            raise ValueError(f"Invalid reduction: {reduction!r}")
        super().__init__(size_average, reduce, reduction)
        # ``reduction`` always wins over the legacy flags handled by the base class.
        self.reduction = reduction
        self.cos = nn.CosineSimilarity(dim=dim, eps=1e-8)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return ``1 - cosine_similarity(input, target)``, reduced per ``self.reduction``."""
        loss = 1 - self.cos(input, target)
        if self.reduction is None or self.reduction == 'none':
            # Keep a trailing singleton axis so the per-sample loss is (..., 1).
            return loss.unsqueeze(-1)
        if self.reduction == 'mean':
            return loss.mean()
        return loss.sum()


def weights_init(m):
    """Initialise a single module's parameters in place.

    Matching is by substring of the module's class name:

    * names containing ``'Linear'``: weight ~ N(0, 0.02), bias = 0;
    * names containing ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0;
    * anything else is left untouched.

    Parameters that are absent or ``None`` (e.g. ``nn.Linear(bias=False)`` or
    ``nn.BatchNorm1d(affine=False)``) are skipped rather than raising
    ``AttributeError``.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return
    weight = getattr(m, 'weight', None)
    if weight is not None:
        weight.data.normal_(weight_mean, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0)


def map_label(label, classes):
    """Map original class ids in ``label`` to their positions in ``classes``.

    Every entry equal to ``classes[i]`` becomes ``i``. Entries that match no
    class become ``-1``; the old version left them as uninitialised memory.

    Args:
        label: tensor of class ids.
        classes: sequence (tensor or list) of class ids defining the order.

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``label``.
    """
    mapped_label = torch.full_like(label, -1, dtype=torch.long)
    for i in range(len(classes)):
        mapped_label[label == classes[i]] = i
    return mapped_label

def map_label_extend(label, new_classes, base_classes):
    """Map ids in ``new_classes`` to indices placed after all ``base_classes``.

    Every entry equal to ``new_classes[i]`` becomes ``i + len(base_classes)``.
    Entries that match no new class become ``-1``; the old version left them
    as uninitialised memory.

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``label``.
    """
    offset = len(base_classes)
    mapped_label = torch.full_like(label, -1, dtype=torch.long)
    for i in range(len(new_classes)):
        mapped_label[label == new_classes[i]] = i + offset
    return mapped_label


def reverse_map_label(label, classes):
    """Inverse of ``map_label``: turn positions ``i`` back into ``classes[i]``.

    Entries outside ``0 .. len(classes)-1`` become ``-1``; the old version
    left them as uninitialised memory.

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``label``.
    """
    mapped_label = torch.full_like(label, -1, dtype=torch.long)
    for i in range(len(classes)):
        mapped_label[label == i] = classes[i]
    return mapped_label


def reverse_map_label_extend(label, new_classes, base_classes):
    """Inverse of ``map_label_extend``.

    Index ``i + len(base_classes)`` is turned back into ``new_classes[i]``.
    Entries outside that index range become ``-1``; the old version left them
    as uninitialised memory.

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``label``.
    """
    offset = len(base_classes)
    mapped_label = torch.full_like(label, -1, dtype=torch.long)
    for i in range(len(new_classes)):
        mapped_label[label == (i + offset)] = new_classes[i]
    return mapped_label

class GenericDataset(Dataset):
    """Map-style dataset that pairs ``_input[idx]`` with ``_target[idx]``.

    When ``cuda`` is truthy both items are moved with ``.cuda()``. After that,
    ``transform`` (if truthy) is applied to the input item only.
    """

    def __init__(self, opt, _input, _target, cuda, transform=None):
        assert len(_input) == len(_target)
        self.opt = opt  # stored, but not read by this class
        self.input, self.target = _input, _target
        self.transform = transform
        self.cuda = cuda

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        sample, label = self.input[idx], self.target[idx]
        if self.cuda:
            sample, label = sample.cuda(), label.cuda()
        if self.transform:
            sample = self.transform(sample)
        return sample, label

class GenericDatasetINV(Dataset):
    """Map-style dataset returning a primary pair plus a parallel "inv" pair.

    ``__getitem__(idx)`` yields ``(input, target, input_inv, target_inv)``,
    each taken at index ``idx``. Only ``_input``/``_target`` lengths are
    checked, and ``len()`` follows ``_input``. When ``cuda`` is truthy all
    four items are moved with ``.cuda()``. ``transform`` (if truthy) is then
    applied to the primary input only.
    """

    def __init__(self, opt, _input, _target, _input_inv, _target_inv, cuda, transform=None):
        assert len(_input) == len(_target)
        self.opt = opt  # stored, but not read by this class
        self.input, self.target = _input, _target
        self.input_inv, self.target_inv = _input_inv, _target_inv
        self.transform = transform
        self.cuda = cuda

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        items = [
            self.input[idx],
            self.target[idx],
            self.input_inv[idx],
            self.target_inv[idx],
        ]
        if self.cuda:
            items = [item.cuda() for item in items]
        if self.transform:
            items[0] = self.transform(items[0])
        return tuple(items)

class Logger(object):
    """Append-only text logger writing to ``<filename>.log``.

    The file is reopened in append mode for every message and closed again
    (via ``with``, so also on error). No handle is held between calls, and
    each message has been flushed by the time ``write`` returns.
    """

    def __init__(self, filename):
        self.filename = filename
        # Create the file (if missing) so it exists before the first message.
        with open(self.filename + '.log', "a", encoding="utf-8"):
            pass

    def write(self, message):
        """Append ``message`` verbatim; no newline is added."""
        with open(self.filename + '.log', "a", encoding="utf-8") as f:
            f.write(message)


def get_mean_features(features, labels, num_classes=50):
    """Return the mean feature vector of each class ``0 .. num_classes-1``.

    Args:
        features: array-like with one row per sample.
        labels: array-like of integer class labels, one per sample (rows are
            selected via ``np.where(labels == c)[0]``).
        num_classes: number of classes to average. Defaults to 50, the value
            that was previously hard-coded.

    Returns:
        ``np.ndarray`` with one row per class. A class with no samples yields
        a NaN row, and NumPy emits a "Mean of empty slice" RuntimeWarning.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)
    return np.array([
        np.mean(features[np.where(labels == cls)[0]], axis=0)
        for cls in range(num_classes)
    ])


class DATA_LOADER(object):
    def __init__(self, opt):
        """Create the loader, reading the .mat dataset when ``opt.matdataset`` is truthy.

        Afterwards ``index_in_epoch`` and ``epochs_completed`` both start at 0.
        """
        if opt.matdataset:
            self.read_matdataset(opt)
        self.index_in_epoch = self.epochs_completed = 0

    def read_matdataset(self, opt):
        """Load visual features, data splits and class embeddings described by ``opt``.

        Reads ``{dataroot}/{dataset}/{image_embedding}.mat`` (``features`` and
        1-based ``labels``) and ``{dataroot}/{dataset}/att_splits.mat`` (split
        index arrays and, for ``opt.expert == "att"``, the ``att`` matrix).
        Class embeddings start from an "expert" embedding (``attribute0``). For
        the attention/mean factual branches it is concatenated with embeddings
        loaded from ``{rootpath}/embeddings/{class_embedding}/...`` (file names
        include ``opt.llm``).

        Sets, among others: ``self.attribute``, ``self.attribute_f``,
        ``self.attribute_inv`` / ``self.attribute_new`` (when
        ``opt.conclude_inv``), the train / test-seen / test-unseen feature and
        label tensors, ``self.seenclasses``, ``self.unseenclasses``,
        ``self.nclass``, ``self.ntrain`` and ``self.train_mapped_label``.

        Side effects on ``opt``: sets ``opt.wordemb_dim`` and may clamp
        ``opt.view_num`` to the number of available views.

        Raises:
            ValueError: unknown ``opt.expert``; unknown ``opt.class_embedding``
                (non-transfer path); or ``opt.selected_view_level`` outside the
                lines of the view-selection file.
        """
        matcontent = sio.loadmat(f"{opt.dataroot}/{opt.dataset}/{opt.image_embedding}.mat")

        # visual feature
        feature = matcontent['features'].T
        if opt.image_embedding[:10] in ['pretrained']:
            # 'pretrained*' embeddings are used as stored (undo the transpose above).
            feature = feature.T
        # .mat labels are 1-based; shift to 0-based class ids.
        label = matcontent['labels'].astype(int).squeeze() - 1

        mat_path = opt.dataroot + "/" + opt.dataset + "/" + "att_splits.mat"
        matcontent = sio.loadmat(mat_path)

        # Each line of the selection file is a comma-separated list of view
        # indices; selected_view_level == k uses line k (1-based).
        selected_view_num_multi = read_lines(f"{opt.rootpath}/prompt_view/aux_info/select_{opt.dataset}.txt")
        selected_view_num_multi = [sel.split(",") for sel in selected_view_num_multi]
        if opt.selected_view_level == 0:
            selected_view_num = list(range(opt.view_num+1))
        else:
            # The former ``for i in range(level): ...[i-1]`` loop ended on line
            # ``level-2`` (i.e. the *last* line for level 1): an off-by-one.
            if not 1 <= opt.selected_view_level <= len(selected_view_num_multi):
                raise ValueError(
                    f"selected_view_level {opt.selected_view_level} is out of range "
                    f"1..{len(selected_view_num_multi)}"
                )
            selected_view_num = [0] + [int(sel) for sel in selected_view_num_multi[opt.selected_view_level - 1]]
        # Index 0 is always kept, so only len(selected_view_num) - 1 views exist.
        if opt.view_num > len(selected_view_num) - 1:
            opt.view_num = len(selected_view_num) - 1

        # numpy array index starts from 0, matlab starts from 1
        trainval_loc = matcontent['trainval_loc'].squeeze() - 1
        train_loc = matcontent['train_loc'].squeeze() - 1
        val_unseen_loc = matcontent['val_loc'].squeeze() - 1
        test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1
        test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1

        # transfer
        if opt.zst:
            # Zero-shot transfer: source-domain class embeddings followed by the
            # target dataset's class embeddings, each row L2-normalised.
            if opt.expert == 'wiki2vec':
                transfer_path = opt.rootpath + "/embeddings/wiki2vec/" + opt.dataset + "_wiki_sum_list.npy"
                if opt.zstfrom == 'imagenet':
                    source_path = opt.rootpath + "/embeddings/wiki2vec/imgnet_wiki_list.npy"
                else:
                    source_path = opt.rootpath + "/embeddings/wiki2vec/" + opt.zstfrom + "_wiki_sum_list.npy"
            elif opt.expert == 'cn':
                transfer_path = opt.rootpath + "/embeddings/conceptnet/" + opt.dataset + "_cn_sum_list.npy"
                if opt.zstfrom == 'imagenet':
                    source_path = opt.rootpath + "/embeddings/conceptnet/imgnet_cn_list.npy"
                else:
                    source_path = opt.rootpath + "/embeddings/conceptnet/" + opt.zstfrom + "_cn_sum_list.npy"
            else:
                # Previously fell through to a NameError on ``transfer_path``.
                raise ValueError(f"Unsupported expert for zero-shot transfer: {opt.expert}")

            # target
            transfer_attributes = torch.from_numpy(np.load(transfer_path, allow_pickle=True)).float()
            transfer_attributes /= torch.norm(transfer_attributes, dim=1)[:, None]
            # source
            source_attributes = torch.from_numpy(np.load(source_path, allow_pickle=True)).float()
            source_attributes /= torch.norm(source_attributes, dim=1)[:, None]
            attribute0 = torch.cat((source_attributes, transfer_attributes))
            opt.wordemb_dim = 512
            all_classes_num = attribute0.shape[0]

            if opt.factual_branch == 'attention':
                embedding_path = f"{opt.rootpath}/embeddings/{opt.class_embedding}/ImageNet_{opt.dataset}_{opt.llm}_{opt.class_embedding}.npy"
                if opt.conclude_inv:
                    if opt.inv_merge:
                        embedding_path_inv = f"{opt.rootpath}/embeddings/{opt.class_embedding}/ImageNet_{opt.dataset}_{opt.llm}_{opt.class_embedding}_inv_merge.npy"
                    else:
                        embedding_path_inv = f"{opt.rootpath}/embeddings/{opt.class_embedding}/ImageNet_{opt.dataset}_{opt.llm}_{opt.class_embedding}_inv.npy"

                # Load BEFORE selecting views: the selection used to run first and
                # read self.attribute before it was ever assigned (AttributeError).
                self.attribute = torch.from_numpy(np.load(embedding_path, allow_pickle=True)).float()
                if opt.conclude_inv:
                    self.attribute_inv = torch.from_numpy(np.load(embedding_path_inv, allow_pickle=True)).float()
                    if opt.inv_merge:
                        self.attribute = torch.cat([self.attribute, self.attribute_inv], dim=0)

                if opt.selected_view_level != 0:
                    self.attribute = self.attribute[:, selected_view_num]
                elif opt.view_num+1 < self.attribute.shape[1]:
                    self.attribute = self.attribute[:, :opt.view_num+1]
                else:
                    print(f"Warning: view_num {opt.view_num} is larger than the attribute dimension {self.attribute.shape[1]}")
                    opt.view_num = self.attribute.shape[1]-1

                # L2-normalise along the last axis, then flatten the rest per row.
                self.attribute = self.attribute / self.attribute.norm(dim=-1, keepdim=True)
                self.attribute = self.attribute.reshape(self.attribute.shape[0], -1)

                if opt.conclude_inv and not opt.inv_merge:
                    self.attribute_inv = self.attribute_inv / self.attribute_inv.norm(dim=-1, keepdim=True)
                    self.attribute_inv = self.attribute_inv.reshape(self.attribute_inv.shape[0], -1)

                if opt.conclude_inv:
                    if opt.inv_merge:
                        # Zero expert rows for the rows appended from the inv file.
                        placeholder_attribute0 = torch.zeros(self.attribute.shape[0] - attribute0.shape[0], attribute0.shape[1])
                        attribute0 = torch.concatenate([attribute0, placeholder_attribute0], axis=0)
                    else:
                        # Zero expert block in front of the inv embeddings.
                        placeholder_attribute0 = torch.zeros(attribute0.shape[0], attribute0.shape[1])
                        self.attribute_inv = torch.concatenate([placeholder_attribute0, self.attribute_inv], axis=1)
                    self.attribute = torch.concatenate([attribute0, self.attribute], axis=1)
                else:
                    self.attribute = torch.concatenate([attribute0, self.attribute[:attribute0.shape[0], :]], axis=1)

            else:
                self.attribute = attribute0

            self.attribute_f = self.attribute[:all_classes_num,:]
            # Keyed on ``opt.inv_merge`` (was ``opt.merge``), matching the
            # non-transfer branch and the inv_merge row stacking above.
            if opt.conclude_inv and opt.inv_merge:
                self.attribute_inv = self.attribute[all_classes_num:2*all_classes_num,:]
                if opt.concatenation:
                    self.attribute_new = self.attribute
                else:
                    self.attribute_new = torch.concatenate([self.attribute_f, self.attribute_inv], axis=1)
            elif opt.conclude_inv:
                self.attribute_new = self.attribute

        # not transfer
        else:
            opt.wordemb_dim = 4096
            embedding_path = f"{opt.rootpath}/embeddings/{opt.class_embedding}/{opt.dataset}_{opt.llm}_{opt.class_embedding}.npy"
            if opt.conclude_inv:
                if opt.inv_merge:
                    embedding_path_inv = f"{opt.rootpath}/embeddings/{opt.class_embedding}/{opt.dataset}_{opt.llm}_{opt.class_embedding}_inv_merge.npy"
                else:
                    embedding_path_inv = f"{opt.rootpath}/embeddings/{opt.class_embedding}/{opt.dataset}_{opt.llm}_{opt.class_embedding}_inv.npy"

            if opt.class_embedding == "clip":
                opt.wordemb_dim = 512
            elif opt.class_embedding == "sbert":
                opt.wordemb_dim = 768
            elif opt.class_embedding == "llama-8b":
                opt.wordemb_dim = 4096
            elif opt.class_embedding == "qwen-7b":
                opt.wordemb_dim = 3584
            else:
                raise ValueError("Invalid embedding model")

            if opt.expert == "att":
                attribute0 = torch.from_numpy(matcontent['att'].T).float()
            elif opt.expert == "wiki2vec":
                expert_embedding_path = f"{opt.rootpath}/embeddings/wiki2vec/{opt.dataset}_wiki_sum_list.npy"
                attribute0 = torch.from_numpy(np.load(expert_embedding_path, allow_pickle=True)).float()
                attribute0 /= torch.norm(attribute0, dim=1)[:, None]
            elif opt.expert == 'cn':
                expert_embedding_path = f"{opt.rootpath}/embeddings/conceptnet/{opt.dataset}_cn_sum_list.npy"
                attribute0 = torch.from_numpy(np.load(expert_embedding_path, allow_pickle=True)).float()
                attribute0 /= torch.norm(attribute0, dim=1)[:, None]
            else:
                # Previously fell through to a NameError on ``attribute0``.
                raise ValueError(f"Invalid expert: {opt.expert}")
            all_classes_num = attribute0.shape[0]

            if opt.factual_branch == 'mean':
                self.attribute = torch.from_numpy(np.load(embedding_path, allow_pickle=True)).float()
                if opt.conclude_inv:
                    self.attribute_inv = torch.from_numpy(np.load(embedding_path_inv, allow_pickle=True)).float()
                    if opt.inv_merge:
                        self.attribute = torch.cat([self.attribute, self.attribute_inv], dim=0)

                if opt.selected_view_level != 0:
                    self.attribute = self.attribute[:, selected_view_num]
                elif opt.view_num+1 <= self.attribute.shape[1]:
                    self.attribute = self.attribute[:, :opt.view_num+1]
                else:
                    print(f"Warning: view_num {opt.view_num} is larger than the attribute dimension {self.attribute.shape[1]-1}")
                    opt.view_num = self.attribute.shape[1]-1

                # NOTE(review): if the loaded tensor is (classes, views, dim), this
                # normalises over dim=1 (views), unlike the attention branch which
                # normalises over the last axis — confirm which is intended.
                self.attribute /= torch.norm(self.attribute, dim=1)[:, None]
                self.attribute = self.attribute.mean(dim=1)

                if opt.conclude_inv and not opt.inv_merge:
                    self.attribute_inv /= torch.norm(self.attribute_inv, dim=1)[:, None]
                    self.attribute_inv = self.attribute_inv.mean(dim=1)

                if opt.conclude_inv:
                    if opt.inv_merge:
                        # Zero expert rows for the rows appended from the inv file.
                        placeholder_attribute0 = torch.zeros(self.attribute.shape[0] - attribute0.shape[0], attribute0.shape[1])
                        attribute0 = torch.concatenate([attribute0, placeholder_attribute0], axis=0)
                    else:
                        # Zero expert block in front of the inv embeddings.
                        placeholder_attribute0 = torch.zeros(attribute0.shape[0], attribute0.shape[1])
                        self.attribute_inv = torch.concatenate([placeholder_attribute0, self.attribute_inv], axis=1)
                    self.attribute = torch.concatenate([attribute0, self.attribute], axis=1)
                else:
                    self.attribute = torch.concatenate([attribute0, self.attribute[:attribute0.shape[0], :]], axis=1)

            elif opt.factual_branch == 'attention':
                self.attribute = torch.from_numpy(np.load(embedding_path, allow_pickle=True)).float()
                if opt.conclude_inv:
                    self.attribute_inv = torch.from_numpy(np.load(embedding_path_inv, allow_pickle=True)).float()
                    if opt.inv_merge:
                        self.attribute = torch.cat([self.attribute, self.attribute_inv], dim=0)

                if opt.selected_view_level != 0:
                    self.attribute = self.attribute[:, selected_view_num]
                elif opt.view_num+1 < self.attribute.shape[1]:
                    self.attribute = self.attribute[:, :opt.view_num+1]
                else:
                    print(f"Warning: view_num {opt.view_num} is larger than the attribute dimension {self.attribute.shape[1]}")
                    opt.view_num = self.attribute.shape[1]-1

                # L2-normalise along the last axis, then flatten the rest per row.
                self.attribute = self.attribute / self.attribute.norm(dim=-1, keepdim=True)
                self.attribute = self.attribute.reshape(self.attribute.shape[0], -1)

                if opt.conclude_inv and not opt.inv_merge:
                    self.attribute_inv = self.attribute_inv / self.attribute_inv.norm(dim=-1, keepdim=True)
                    self.attribute_inv = self.attribute_inv.reshape(self.attribute_inv.shape[0], -1)

                if opt.conclude_inv:
                    if opt.inv_merge:
                        # Zero expert rows for the rows appended from the inv file.
                        placeholder_attribute0 = torch.zeros(self.attribute.shape[0] - attribute0.shape[0], attribute0.shape[1])
                        attribute0 = torch.concatenate([attribute0, placeholder_attribute0], axis=0)
                    else:
                        # Zero expert block in front of the inv embeddings.
                        placeholder_attribute0 = torch.zeros(attribute0.shape[0], attribute0.shape[1])
                        self.attribute_inv = torch.concatenate([placeholder_attribute0, self.attribute_inv], axis=1)
                    self.attribute = torch.concatenate([attribute0, self.attribute], axis=1)
                else:
                    self.attribute = torch.concatenate([attribute0, self.attribute[:attribute0.shape[0], :]], axis=1)

            else:
                self.attribute = attribute0

            self.attribute_f = self.attribute[:all_classes_num,:]

            if opt.conclude_inv and opt.inv_merge:
                self.attribute_inv = self.attribute[all_classes_num:2*all_classes_num,:]
                if opt.concatenation:
                    self.attribute_new = self.attribute
                else:
                    self.attribute_new = torch.concatenate([self.attribute_f, self.attribute_inv], axis=1)
            elif opt.conclude_inv:
                self.attribute_new = self.attribute

        print(f"Loaded factual attribute shape: {self.attribute_f.shape}")
        if opt.conclude_inv:
            print(f"Loaded intervention attribute shape: {self.attribute_inv.shape}")
            if opt.inv_merge:
                print(f"Loaded new constructed attribute shape: {self.attribute_new.shape}")

        # feature and label statistics
        if opt.preprocessing:
            if opt.standardization:
                scaler = preprocessing.StandardScaler()
            else:
                scaler = preprocessing.MinMaxScaler()

            # Fit on the trainval split only, then divide every split by the
            # maximum of the scaled training features.
            _train_feature = scaler.fit_transform(feature[trainval_loc])
            _test_seen_feature = scaler.transform(feature[test_seen_loc])
            _test_unseen_feature = scaler.transform(feature[test_unseen_loc])
            self.train_feature = torch.from_numpy(_train_feature).float()
            mx = self.train_feature.max()
            self.train_feature.mul_(1/mx)
            self.train_label = torch.from_numpy(label[trainval_loc]).long()

            self.test_unseen_feature = torch.from_numpy(_test_unseen_feature).float()
            self.test_unseen_feature.mul_(1/mx)
            self.test_unseen_label = torch.from_numpy(label[test_unseen_loc]).long()

            self.test_seen_feature = torch.from_numpy(_test_seen_feature).float()
            self.test_seen_feature.mul_(1/mx)
            self.test_seen_label = torch.from_numpy(label[test_seen_loc]).long()
        else:
            self.train_feature = torch.from_numpy(feature[trainval_loc]).float()
            self.train_label = torch.from_numpy(label[trainval_loc]).long()
            self.test_unseen_feature = torch.from_numpy(feature[test_unseen_loc]).float()
            self.test_unseen_label = torch.from_numpy(label[test_unseen_loc]).long()
            self.test_seen_feature = torch.from_numpy(feature[test_seen_loc]).float()
            self.test_seen_label = torch.from_numpy(label[test_seen_loc]).long()

        self.seenclasses = torch.from_numpy(np.unique(self.train_label.numpy()))
        self.unseenclasses = torch.from_numpy(np.unique(self.test_unseen_label.numpy()))

        if opt.zst:
            # Source classes occupy ids 0..len(source_attributes)-1 in attribute0,
            # so the target (unseen) ids are shifted past them.
            self.unseenclasses = self.unseenclasses + len(source_attributes)
            self.seenclasses = torch.arange(len(source_attributes))

        print(self.unseenclasses)
        print(self.seenclasses)

        self.nclass = len(self.seenclasses) + len(self.unseenclasses)
        self.ntrain = self.train_feature.size()[0]
        self.train_mapped_label = map_label(self.train_label, self.seenclasses)
