import torch, argparse, sys, json, os, warnings, re
from torch.utils.data import Dataset, ConcatDataset

# Make the project's ../src modules (utils, models, detector_utils) importable
# and switch the working directory there. The relative config path used in
# PhylogenyDataset.create_dataset ("./Models/config/...") is resolved against it.
# NOTE(review): os.chdir at import time changes the cwd of the whole process,
# affecting every relative path used after this module is imported.
sys.path.append('../src')
os.chdir("../src")
# Silence all UserWarnings process-wide.
warnings.simplefilter(action='ignore', category=UserWarning)

from utils import *
from models.basic import *
from models.resnets import *
from models.vit import *

from detector_utils.dataset import get_split

class PhylogenyDataset(Dataset):
    """Holds one selected layer's weights for a set of parent and child models.

    Parent and child checkpoint paths are read from the JSON lists
    ``./Models/config/Vision_Classification/<pmodel>.json`` and
    ``.../<cmodel>.json`` (relative to the current working directory). For every
    checkpoint a single weight tensor, chosen by model family (see
    ``get_weights``), is loaded and stacked. Every child is labelled with the
    index of the parent whose directory occurs in the child's path.

    This base class does not define per-item access; subclasses must
    implement ``__getitem__``.
    """

    # (token, family) pairs checked in order against the config name. The 3D
    # variants must precede their 2D counterparts because "FC" is a substring
    # of "FC3D" and "CONV" is a substring of "CONV3D".
    _MODEL_FAMILIES = (
        ("FC3D", "fc3d"),
        ("FC", "fc"),
        ("CONV3D", "conv3d"),
        ("CONV", "conv"),
        ("RESNET18", "resnet18"),
        ("VITT", "vit"),
    )

    def __init__(self, pmodel, cmodel):
        super().__init__()
        self.create_dataset(pmodel, cmodel)

    def get_real_parents(self, listp, listc):
        """Return, for each child path, the index of its parent in ``listp``.

        A parent is identified by the directory part of its path (everything
        before the last ``/``). A child belongs to a parent if that directory
        string occurs in the child's path. When several parent directories
        match (e.g. ``m/p1`` and ``m/p10`` both occur in ``m/p10/c/x.pt``),
        the longest one wins, so a directory that is merely a substring of
        another is not chosen by accident. Ties keep the earliest parent.

        Args:
            listp: parent checkpoint paths.
            listc: child checkpoint paths.

        Returns:
            ``torch.long`` tensor of length ``len(listc)`` with parent indices.

        Raises:
            RuntimeError: if some child path matches no parent directory.
        """
        parent_dirs = [p.rsplit("/", 1)[0] for p in listp]
        true_label = []
        for c in listc:
            best_index, best_len = None, -1
            for index, p in enumerate(parent_dirs):
                # Strict ">" keeps the first parent on ties (original order).
                if p in c and len(p) > best_len:
                    best_index, best_len = index, len(p)
            if best_index is None:
                raise RuntimeError("Not True Label")
            true_label.append(best_index)
        return torch.tensor(true_label).long()

    def get_modelname(self, pmodel, cmodel):
        """Map the parent config name to a model-family key.

        Raises:
            RuntimeError: if the parent name contains no known family token,
                or the child name does not contain the parent's token.
        """
        for token, family in self._MODEL_FAMILIES:
            if token in pmodel:
                # Explicit check instead of `assert`, which `python -O` strips.
                if token not in cmodel:
                    raise RuntimeError(
                        f"Child model {cmodel!r} does not match parent family {token!r}"
                    )
                return family
        raise RuntimeError("Model Name Not Found")

    def create_dataset(self, pmodel, cmodel):
        """Load the parent/child checkpoint lists, their layer weights and labels."""
        self.pmodel, self.cmodel = pmodel, cmodel
        self.modelname = self.get_modelname(pmodel, cmodel)

        json_folder = "./Models/config/Vision_Classification"

        self.pjson = os.path.join(json_folder, pmodel + ".json")
        self.cjson = os.path.join(json_folder, cmodel + ".json")

        # Each file holds a JSON list of checkpoint path strings.
        with open(self.pjson, "r") as f:
            self.pmodelpaths = json.load(f)
        with open(self.cjson, "r") as f:
            self.cmodelpaths = json.load(f)

        self.num_p = len(self.pmodelpaths)
        self.num_c = len(self.cmodelpaths)

        # detach() yields tensors with requires_grad=False: the weights are
        # treated as input data, not trainable parameters.
        self.weight_p = self.get_weights(self.pmodelpaths, self.modelname).detach()
        self.weight_c = self.get_weights(self.cmodelpaths, self.modelname).detach()

        self.label_p = torch.arange(self.num_p)
        self.label_c = self.get_real_parents(self.pmodelpaths, self.cmodelpaths)

    @property
    def complete_dataset(self):
        """Return ``(weight_p, weight_c, labels)`` with weights flattened per model.

        Weights are flattened to ``(num_models, features)``. If the parents have
        more than 1024 features, both parents and children are cut to the first
        1024. ``labels`` concatenates the parent labels (``0..num_p-1``) and
        the child-to-parent labels.
        """
        weight_p, weight_c = self.weight_p.flatten(1), self.weight_c.flatten(1)
        labels = torch.cat((self.label_p, self.label_c), dim=0)
        if weight_p.shape[1] > 1024:
            return weight_p[:, :1024], weight_c[:, :1024], labels
        return weight_p, weight_c, labels

    def get_weights(self, modelpaths, modelname):
        """Load one named weight tensor from every checkpoint and stack them.

        Each checkpoint may be a state-dict-like ``dict`` or a pickled
        ``torch.nn.Module`` (converted via ``state_dict()``).

        Returns:
            Tensor of shape ``(len(modelpaths), *weight_shape)``.
        """
        model_weight_name_map = {
            "fc": "fc1.weight",
            "fc3d": "fc1.weight",
            "conv": "conv1.weight",
            "conv3d": "conv1.weight",
            "resnet18": "layer1.0.conv1.weight",
            "resnet18-EMNIST-Letters": "conv1.weight",
            "resnet18-FMNIST": "conv1.weight",
            "vit": "blocks.2.mlp.fc2.weight",
        }
        weight_name = model_weight_name_map[modelname]
        weights = []
        for modelpath in modelpaths:
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. On torch>=2.6 torch.load defaults to weights_only=True,
            # which refuses full nn.Module pickles; confirm the torch version
            # if the Module branch below is needed.
            model = torch.load(modelpath, map_location=torch.device('cpu'))
            if isinstance(model, torch.nn.Module):
                model = model.state_dict()
            weights.append(model[weight_name])
        return torch.stack(weights, dim=0)

    def __getitem__(self, index):
        # The original body referenced undefined names; fail explicitly instead.
        raise NotImplementedError(
            "PhylogenyDataset does not support item access; subclasses must implement __getitem__"
        )

    def __len__(self):
        return self.num_c

class PhylogenyDetectorDataset(PhylogenyDataset):
    """Groups parents with clusters of their children into fixed-size samples.

    Each parent's children are cut into clusters of ``child_per_parent``
    (children that do not fill a last cluster are dropped). One sample combines
    one cluster from each of ``parent_per_sample`` distinct parents. Item ``i``
    returns the parents' weights followed by their children's weights. The
    labels give, for every row, the position (``0..parent_per_sample-1``) of
    its parent within the sample.

    Args:
        pmodel, cmodel: config names forwarded to ``PhylogenyDataset``.
        child_per_parent: children per parent in one sample (must be >= 1).
        parent_per_sample: distinct parents per sample (must be >= 1).

    Raises:
        ValueError: if ``child_per_parent`` or ``parent_per_sample`` is < 1.
            A zero value would divide by zero or loop forever in
            ``get_sub_cluster``.
    """

    def __init__(self, pmodel, cmodel, child_per_parent = 3, parent_per_sample = 3):
        # Validate before the (expensive) checkpoint loading in the base class.
        if child_per_parent < 1 or parent_per_sample < 1:
            raise ValueError(
                f"child_per_parent and parent_per_sample must be >= 1, "
                f"got {child_per_parent} and {parent_per_sample}"
            )
        super().__init__(pmodel, cmodel)
        self.child_per_parent = child_per_parent
        self.parent_per_sample = parent_per_sample

        self.get_sub_cluster()
        self.expansion_setup()

    def expansion_setup(self):
        """Define the crop windows applied in ``__getitem__``.

        Every base sample is exposed ``expansion_factor`` times, each time
        cropped with one ``[row_start, row_end, col_start, col_end]`` window
        from ``expansion_index`` (applied to dims 1 and 2 of the weight tensor).
        Currently this is a single top-left 32x32 crop.
        """
        self.expansion_factor = 1
        self.expansion_index = [[0, 32, 0, 32]]

    def get_sub_cluster(self):
        """Split each parent's children into clusters and pair them into samples.

        Sets:
            self.p_c_dict: parent index -> LongTensor of shape
                ``(num_clusters, child_per_parent)`` holding child indices.
            self.DatasetIndex_PC_map: list of samples, each a list of
                ``parent_per_sample`` ``(parent_index, cluster_index)`` pairs
                drawn from distinct parents.
        """
        p_c_dict = {}
        finished_dict = {}
        for p_index in range(self.num_p):
            corr_c_index = torch.where(self.label_c == p_index)[0]
            # Drop the remainder so the children divide evenly into clusters.
            usable = len(corr_c_index) - len(corr_c_index) % self.child_per_parent
            corr_c_index = torch.sort(corr_c_index[:usable])[0]
            corr_c_index = corr_c_index.view(-1, self.child_per_parent)
            p_c_dict[p_index] = corr_c_index
            finished_dict[p_index] = {"cur": 0, "total": corr_c_index.shape[0]}

        DatasetIndex_PC_map = []
        # Each round scans parents in index order and takes the next unused
        # cluster from every parent that still has one. The round stops once
        # parent_per_sample parents are collected. The first round that cannot
        # fill a sample ends the pairing.
        while True:
            PC_pairs = []
            for p_index in range(self.num_p):
                state = finished_dict[p_index]
                if state["cur"] < state["total"]:
                    PC_pairs.append((p_index, state["cur"]))
                    state["cur"] += 1
                if len(PC_pairs) >= self.parent_per_sample:
                    break
            if len(PC_pairs) < self.parent_per_sample:
                break
            DatasetIndex_PC_map.append(PC_pairs)

        self.p_c_dict = p_c_dict
        self.DatasetIndex_PC_map = DatasetIndex_PC_map

    def __getitem__(self, index):
        """Return ``(weight, label)`` for sample ``index``.

        ``weight`` stacks the ``parent_per_sample`` parent weights and then
        ``parent_per_sample * child_per_parent`` child weights along dim 0. It
        is reshaped per model family and cropped with the sample's expansion
        window. ``label`` is ``arange(parent_per_sample)`` followed by each
        parent's position repeated ``child_per_parent`` times.
        """
        expansion_index = self.expansion_index[index % self.expansion_factor]
        index = index // self.expansion_factor

        PC_pairs = self.DatasetIndex_PC_map[index]
        p_label = torch.arange(self.parent_per_sample)
        c_label = torch.repeat_interleave(p_label, self.child_per_parent)

        p_weight = torch.stack([self.weight_p[p_index] for p_index, _ in PC_pairs], dim=0)
        # Indexing with a cluster's index tensor yields child_per_parent rows.
        c_weight = torch.cat(
            [self.weight_c[self.p_c_dict[p_index][c_index]] for p_index, c_index in PC_pairs],
            dim=0,
        )

        weight = torch.cat((p_weight, c_weight), dim=0)
        label = torch.cat((p_label, c_label), dim=0)

        # "fc"/"fc3d" (1024x784, 1024x2352) and "vit" are used unchanged.
        if self.modelname == "conv":  # 16x1x3x3
            weight = weight.view(-1, 16, 9)
        elif self.modelname == "conv3d":  # 16x3x3x3
            weight = weight.view(-1, 16, 27)
        elif self.modelname == "resnet18":
            # presumably 64x64x3x3 per model (64*64*9 == 192*192)
            weight = weight.view(-1, 64 * 3, 64 * 3)
        elif self.modelname not in ("fc", "fc3d", "vit"):
            raise RuntimeError("Model Name Not Found")

        weight = weight[:, expansion_index[0]:expansion_index[1], expansion_index[2]:expansion_index[3]]

        return weight, label

    def __len__(self):
        return len(self.DatasetIndex_PC_map) * self.expansion_factor

def get_phylogeny_loader(args, shuffle = True, num_workers = 4):
    """Build DataLoaders over one PhylogenyDetectorDataset per (pmodel, cmodel) pair.

    Args:
        args: namespace providing ``pmodels`` and ``cmodels`` (equal-length
            sequences of config names), ``batch_size`` and ``test_batch_size``.
        shuffle: whether the training loader shuffles.
        num_workers: worker processes for every loader (default 4, as before).

    Returns:
        ``(train_loader, val_loader, test_loaders, full_loaders, parent_per_sample)``.
        ``test_loaders`` and ``full_loaders`` hold one loader per pair, and
        ``parent_per_sample`` comes from the last dataset built.

    Raises:
        ValueError: if no pairs are given or ``pmodels``/``cmodels`` differ in length.
    """
    pmodels, cmodels = list(args.pmodels), list(args.cmodels)
    # Without these checks an empty list leaves `dataset` unbound at return,
    # and zip() would silently drop unmatched names.
    if not pmodels:
        raise ValueError("args.pmodels is empty; at least one (pmodel, cmodel) pair is required")
    if len(pmodels) != len(cmodels):
        raise ValueError(
            f"args.pmodels and args.cmodels differ in length ({len(pmodels)} vs {len(cmodels)})"
        )

    train_kwargs = {'batch_size': args.batch_size, 'num_workers': num_workers, 'shuffle': shuffle}
    val_kwargs = {'batch_size': args.test_batch_size, 'num_workers': num_workers, 'shuffle': False}
    test_kwargs = {'batch_size': args.test_batch_size, 'num_workers': num_workers, 'shuffle': False}

    trainsets, valsets, testsets, fulldatasets = [], [], [], []
    for pmodel, cmodel in zip(pmodels, cmodels):
        dataset = PhylogenyDetectorDataset(pmodel, cmodel)
        train_set, val_set, test_set = get_split(dataset)
        trainsets.append(train_set)
        valsets.append(val_set)
        testsets.append(test_set)
        fulldatasets.append(dataset)

    # NOTE(review): the training loader concatenates the train AND validation
    # splits, so val_loader evaluates on data that is also trained on. Kept as
    # is; confirm this is intended (e.g. a final fit after model selection).
    train_loader = torch.utils.data.DataLoader(ConcatDataset(trainsets + valsets), **train_kwargs)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(valsets), **val_kwargs)
    test_loaders = [torch.utils.data.DataLoader(test_set, **test_kwargs) for test_set in testsets]
    full_loaders = [torch.utils.data.DataLoader(fulldataset, **val_kwargs) for fulldataset in fulldatasets]

    return train_loader, val_loader, test_loaders, full_loaders, dataset.parent_per_sample