from Sampler import Sampler
import torch.nn.functional as F
import numpy as np
import torch

from utils import set_seed
from utils import plotSamples
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import sys

from tqdm.auto import tqdm
import random as rand
from trainprior import get_model_with_ml

# Module-level protocol settings read by evaluate_splitmnist.

# When True, labels of every task are shifted into {0, 1} (domain-incremental);
# when False the original digit labels are kept.
DOMAIN_INCREMENTAL = False

# Number of two-class tasks the evaluation iterates over.
NUM_TASKS = 5

# Support examples drawn per class for each task (used as the k-shot size).
NUM_EXAMPLES = 15

def evaluate_with_ml(model, params_train, dataset_train, params_test, dataset_test, verbose = False):
    """Prepare ``model`` with ``get_model_with_ml`` and evaluate the result.

    Args:
        model, params_train, dataset_train: forwarded unchanged to
            ``get_model_with_ml`` (imported from ``trainprior``).
        params_test, dataset_test: forwarded to ``evaluate`` together with the
            model returned by ``get_model_with_ml``.
        verbose: forwarded to ``evaluate`` as a keyword argument.

    Returns:
        Whatever ``evaluate`` returns.

    NOTE(review): ``evaluate`` is not defined or imported in the visible part
    of this file - confirm it is provided elsewhere before relying on this.
    """
    prepared_model = get_model_with_ml(model, params_train, dataset_train)
    outcome = evaluate(params_test, prepared_model, dataset_test, verbose=verbose)
    return outcome

def get_embeded_dataset(model_base, dataset_test):
    """Embed every sample of ``dataset_test`` and regroup the features per class.

    The dataset is read once, in order, in batches of 128.  Each batch is
    passed through ``model_base.back_bone`` with gradients disabled and the
    outputs are moved to the CPU.

    Args:
        model_base: object exposing ``back_bone`` (called on input batches and
            switched with ``.eval()``) and ``parameters()``; the device of the
            first parameter is where batches are sent.
        dataset_test: dataset yielding ``(x, y)`` pairs.  Samples must be
            ordered so that each class forms exactly one contiguous run, and
            every class must contain the same number of samples.

    Returns:
        A ``torch.utils.data.Dataset`` of ``(embedding, label)`` pairs in the
        same order as ``dataset_test``; ``label`` is the plain Python value
        produced by ``Tensor.tolist()``.

    Raises:
        ValueError: if the dataset yields no samples, or if the labels are not
            organised as equally sized contiguous runs of distinct classes.
            (Previously such input either failed inside ``reshape`` or was
            silently regrouped under the wrong labels.)
    """
    loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=128, shuffle=False, num_workers=0)
    # .eval() is called on the backbone itself, so it stays in eval mode
    # after this function returns.
    backbone = model_base.back_bone.eval()
    device = next(model_base.parameters()).device

    feature_chunks = []
    label_chunks = []
    for x, y in tqdm(loader):
        with torch.no_grad():
            feature_chunks.append(backbone(x.to(device)).cpu())
        label_chunks.append(y)
        del x, y
        torch.cuda.empty_cache()

    if not feature_chunks:
        raise ValueError("dataset_test produced no samples to embed")

    embeded_data = torch.cat(feature_chunks)
    embeded_labels = torch.cat(label_chunks)

    # One entry per contiguous run of identical labels, with the run lengths.
    run_labels, run_sizes = embeded_labels.unique_consecutive(return_counts=True)
    active_classes = run_labels.tolist()
    num_classes = len(active_classes)
    elements_per_class = int(run_sizes[0])

    # The reshape below is only meaningful when every class occupies exactly
    # one run and all runs have the same length.
    has_repeated_class = len(set(active_classes)) != num_classes
    has_uneven_runs = bool((run_sizes != elements_per_class).any())
    if has_repeated_class or has_uneven_runs:
        raise ValueError(
            "dataset_test must be ordered class by class with the same number "
            "of samples per class; got run labels "
            f"{active_classes} with sizes {run_sizes.tolist()}")

    embeded_data = embeded_data.reshape([num_classes, elements_per_class, -1])
    per_class_features = dict(zip(active_classes, embeded_data))

    class EmbededDataset(torch.utils.data.Dataset):
        # label -> tensor of that class's embeddings, indexed by position
        # within the class.  Kept as a class attribute as before.
        embeded_data_dict = per_class_features

        def __init__(self, active_classes, elements_per_class):
            self.elements_per_class = elements_per_class
            self.active_classes = active_classes

        def __len__(self):
            return len(self.active_classes) * self.elements_per_class

        def __getitem__(self, index):
            # Floor division / modulo keep negative indices consistent
            # (dataset[-1] is the last sample); an out-of-range index raises
            # IndexError from the list lookup.
            outer_index = index // self.elements_per_class
            inner_index = index % self.elements_per_class

            label = self.active_classes[outer_index]
            return EmbededDataset.embeded_data_dict[label][inner_index], label

    return EmbededDataset(active_classes, elements_per_class)


def evaluate_splitmnist(params, model, way_keys = None, IID_epochs = 0, verbose = False):
    """Run the sequential Split-MNIST evaluation and return accuracy statistics.

    MNIST digits are split into ``NUM_TASKS`` two-class tasks (digits ``2t``
    and ``2t + 1`` for task ``t``).  For each of 10 runs the tasks are
    presented in order: ``NUM_EXAMPLES`` support images per class are taken
    from the MNIST training split (minus its last 5000 images), embedded with
    ``model.back_bone`` and passed to ``model.update_prototype``.  After each
    task the model is scored on the MNIST test split of every task seen so
    far; the mean over tasks after the final task is recorded for the run.

    Args:
        params: object providing ``mode`` (must be ``'test'`` or ``'val'``),
            ``seed`` and ``meta_test_steps`` (only used in a progress message).
        model: object providing ``parameters()``, ``eval()``, ``back_bone``,
            ``reset_episode``, ``prototypes``, ``update_prototype`` and a
            call signature accepting ``embed=False``.
        way_keys: when not ``None`` the function attempts to build a detailed
            report via ``evaluate_continual`` instead of returning summary
            statistics (see the NOTE near the end of the body).
        IID_epochs: only used in the mutual-exclusion assertion and a status
            message; may not be combined with ``way_keys``.
        verbose: accepted for call compatibility; not used.

    Returns:
        ``(mean, std)`` of the per-run accuracies when ``way_keys`` is
        ``None``; otherwise the ``detailed_report`` dict.
    """
    assert params.mode == 'test' or params.mode == 'val'
    assert not ((way_keys is not None) and (IID_epochs != 0)), "Test can be either continual or IID, not both"

    device = next(model.parameters()).device

    with torch.no_grad():
        set_seed(params.seed)
        model.eval()

        import torchvision
        from torchvision.transforms import Compose, Resize, ToTensor
        print("Preparing Split-MNIST...")
        compat_shape = [1, 84, 84]
        k_shot_train = NUM_EXAMPLES
        # Inputs are deliberately not normalised (author's note: "no
        # normalization is better").
        mnist_transform = Compose(
            [Resize(84), ToTensor()])

        extra_dataset = torchvision.datasets.MNIST(
            download=True, root='./data', train=True)

        # Use all training images except the last 5000 as the support pool.
        train_indices = np.arange(len(extra_dataset))[:-5000]

        extra_dataset.targets = extra_dataset.targets[train_indices]
        extra_dataset.data = extra_dataset.data[train_indices]

        from torch.utils.data import Dataset
        class TransformedDataset(Dataset):
            """Eagerly transformed copy of a dataset, held as two tensors."""

            def __init__(self, dataset, transform):
                data_list = []
                targets_list = []
                self.transform = transform

                for index in range(len(dataset)):
                    raw_data, _ = dataset[index]
                    label = dataset.targets[index]
                    transformed_data = self.transform(raw_data)
                    data_list.append(transformed_data)
                    if isinstance(label, int):
                        label = torch.tensor(label)
                    targets_list.append(label)
                self.data = torch.stack(data_list, dim=0)
                self.targets = torch.stack(targets_list, dim=0)

            def __len__(self):
                # Fixed: previously returned the data tensor itself.
                return len(self.data)

        extra_dataset = TransformedDataset(extra_dataset, mnist_transform)

        # One test loader per task, restricted to that task's two digits.
        split_mnist_test_loaders = {}

        for split_id in range(NUM_TASKS):
            # A fresh MNIST object is needed per task because its targets/data
            # are overwritten below.
            test_set = torchvision.datasets.MNIST(
                root='./data', train=False, transform=mnist_transform, download=True)
            idx_0 = test_set.targets == (split_id * 2)
            idx_1 = test_set.targets == (split_id * 2+1)
            idx_0_np = idx_0.numpy()
            idx_1_np = idx_1.numpy()
            idx_np = np.logical_or(idx_0_np, idx_1_np).astype(int)
            idx_np = np.argwhere(np.asarray(idx_np))
            idx = torch.from_numpy(idx_np).view(-1)
            if DOMAIN_INCREMENTAL:
                # Domain-incremental: every task uses labels {0, 1}.
                test_set.targets = test_set.targets[idx] - split_id * 2
            else:
                test_set.targets = test_set.targets[idx]

            test_set.data = test_set.data[idx]

            # NOTE(review): drop_last=True discards the final partial batch, so
            # up to 127 test samples per task are never scored - confirm this
            # is intended.
            extra_test_loader = torch.utils.data.DataLoader(
                dataset=test_set, batch_size=128, shuffle=False,
                pin_memory=True, num_workers=2, drop_last=True)

            split_mnist_test_loaders[split_id] = extra_test_loader

        if IID_epochs > 0 :
            print (f"Running IID test with {IID_epochs} epochs")

        detailed_report = {}

        run_acc = []
        for run_id in range(10):

            for outer_loop in tqdm (range(0, NUM_TASKS), desc="meta test outer loop"):
                print(f'========   {outer_loop}  =======')
                accuracies = []

                set_seed(params.seed + run_id + 42)

                split_id = outer_loop
                extra_train_data = []
                extra_train_labels = []

                # Each run uses its own disjoint slice of k_shot_train samples
                # per class.
                shot_slice = slice(k_shot_train * run_id, k_shot_train * (run_id + 1))
                for class_id in range(split_id * 2, split_id * 2 + 2):
                    indices = extra_dataset.targets == class_id
                    extra_train_data.append(extra_dataset.data[indices][shot_slice].to(device))
                    extra_train_labels.append(extra_dataset.targets[indices][shot_slice].to(device))

                # Stacking on dim=1 and flattening interleaves the two classes
                # shot by shot (a, b, a, b, ...).
                extra_train_data = torch.stack(extra_train_data, dim=1)
                xs = extra_train_data.reshape(2 * k_shot_train, *compat_shape)

                extra_train_labels = torch.stack(extra_train_labels, dim=1)
                if DOMAIN_INCREMENTAL:
                    ys = extra_train_labels.reshape(2 * k_shot_train) - split_id * 2
                else:
                    ys = extra_train_labels.reshape(2 * k_shot_train)

                # Embed the support set once; the model is later called with
                # embed=False.
                xs = model.back_bone(xs)

                if DOMAIN_INCREMENTAL:
                    labels = ys.unique()
                    task_labels = labels
                else:
                    labels = torch.arange(
                        0, NUM_TASKS*2, dtype=int, device=ys.device)
                    task_labels = labels[outer_loop*2:(outer_loop+1)*2]
                targets = torch.zeros_like(ys)
                if outer_loop == 0:
                    print('First task, reset model')
                    model.reset_episode(labels.to(device))
                    print(model.prototypes)

                for local_idx in range(0, 2):
                    label = task_labels[local_idx]
                    # Prototype slot: local index in domain-incremental mode,
                    # global class index otherwise.
                    if DOMAIN_INCREMENTAL:
                        proto_idx = local_idx
                    else:
                        proto_idx = local_idx + outer_loop*2
                    selector = ys == label

                    targets[selector] = proto_idx

                    model.update_prototype(xs[selector].to(device), proto_idx, embed=False)
                    torch.cuda.empty_cache()

                # Score every task seen so far.
                for task_id in range(0, outer_loop+1):

                    acc = 0.0
                    count = 0

                    for xq, yq in split_mnist_test_loaders[task_id]:
                        xq = xq.to(device)
                        yq = yq.to(device)

                        xq = model.back_bone(xq)

                        for local_idx in range(0, 2):
                            if DOMAIN_INCREMENTAL:
                                label = labels[local_idx]
                                expected = local_idx
                            else:
                                label = labels[task_id*2 + local_idx]
                                expected = local_idx + task_id*2

                            selector = yq==label
                            scores = model(xq[selector].to(device),  embed=False)
                            torch.cuda.empty_cache()
                            pred = scores.argmax(dim=-1)

                            acc += (pred==expected).sum().item()
                            count += pred.shape[0]

                    acc = acc / count
                    print(f'task {task_id} acc: {acc * 100:.1f}')
                    accuracies.append(acc)

                print(f"[{outer_loop}/{params.meta_test_steps}]: this epoch acc:{acc}, mean acc:{np.mean(accuracies)}, std:{np.std(accuracies)}")
            # `accuracies` now holds the per-task results after the last task.
            if task_id == NUM_TASKS - 1:
                run_acc.append(np.mean(accuracies))
        print(f"Mean acc:{np.mean(run_acc)}, std:{np.std(run_acc)}")
        print(run_acc)

        if way_keys is not None:
            # NOTE(review): `predicted_logits` is not assigned anywhere in this
            # function; unless it exists as a module-level name outside the
            # visible code this branch raises NameError. Left unchanged because
            # the intended logits cannot be determined from here.
            detailed_report[outer_loop] = evaluate_continual(logits = predicted_logits, targets = targets, ways = way_keys)
        if way_keys is None:
            return np.mean(run_acc), np.std(run_acc)
        else:
            return detailed_report


def evaluate_continual(logits, targets, ways):
    """Build per-way predictions restricted to the first ``way`` classes.

    Args:
        logits: tensor indexed as ``[sample, class]``.
        targets: tensor of class indices, aligned with the first axis of
            ``logits``.
        ways: iterable of class counts to report on.

    Returns:
        Dict mapping each ``way`` to ``(predictions, kept_targets)``: only
        samples whose target is ``<= way - 1`` are kept, and predictions are
        the argmax over the first ``way`` logit columns.
    """
    def _restricted(way):
        keep = targets <= (way - 1)
        return logits[keep, :way].argmax(-1), targets[keep]

    return {way: _restricted(way) for way in ways}

def report_plot_values(model, dataset_test, dataset_name):
    """Evaluate ``model`` over a fixed grid of way/shot settings per dataset.

    For every ``(way, shot)`` combination configured for ``dataset_name`` the
    module-level ``params_test`` object is updated in place and ``evaluate``
    is called with the corresponding list of ways.

    Args:
        model: forwarded to ``evaluate``.
        dataset_test: forwarded to ``evaluate``.
        dataset_name: ``"O"``, ``"M"`` or ``"C"``; selects the way/shot grid.

    Returns:
        Nested dict ``results[way_key][shot_key]`` holding the value returned
        by ``evaluate`` for that setting.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported keys
            (previously this surfaced as an ``UnboundLocalError``).

    NOTE(review): relies on a module-level ``params_test`` (only created in
    the ``__main__`` section) and on ``evaluate``, which is not defined in the
    visible part of this file - confirm both exist before calling.
    """
    if dataset_name == "O":
        way_keys = {10:[10], 50:[50], 100:[100], 150:[150], 200:[200], 250:[250], 300:[300], 350:[350], 400:[400], 450:[450],
                    500:[500], 550:[550] ,
                    600:[10, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600]}
        shot_keys = [15]
    elif dataset_name == "M":
        shot_keys = [5, 30, 60, 90, 120, 150, 180, 210, 240]
        way_keys = {5:[5],10:[10],15:[15],20:[5,10,15,20]}
    elif dataset_name == "C":
        shot_keys = [30]
        way_keys = {3:[3],5:[5],10:[10],15:[15],20:[20],25:[25], 30:[3,5,10,15,20,25,30]}
    else:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected 'O', 'M' or 'C'")

    # The only per-dataset difference in the test parameters.
    query_train_inner_step = 5 if dataset_name == "O" else 100

    results = {}
    for way_key, way_list in way_keys.items():
        results[way_key] = {}
        for shot_key in tqdm(shot_keys):

            print ("running for shot" , shot_key, "and way", way_key, "way list:", way_list)

            params_test.query_num_train_tasks = params_test.support_num_train_tasks = way_key
            params_test.support_inner_step = shot_key
            params_test.query_train_inner_step = query_train_inner_step
            params_test.meta_test_steps = 100

            results[way_key][shot_key] = evaluate(params_test, model, dataset_test, way_list)

    return results

if __name__=="__main__":
    # Usage: <script> <dataset_name: O|M|C> <checkpoint_prefix>
    dataset_name = sys.argv[1]
    # Plain string concatenation: the prefix must already end with a path
    # separator if it names a directory.
    check_point_path = sys.argv[2] + "check_point"

    if (dataset_name == "O"):
        from datasets.omniglot.TrainParams import MetaTrainParams
        from datasets.omniglot.TestParams  import MetaTestParams
    elif (dataset_name == "M"):
        from datasets.miniimagenet.TrainParams import MetaTrainParams
        from datasets.miniimagenet.TestParams  import MetaTestParams
    elif (dataset_name == "C"):
        from datasets.cifar100.TrainParams import MetaTrainParams
        from datasets.cifar100.TestParams  import MetaTestParams
    else:
        # Fail here with a clear message instead of a later NameError on
        # MetaTrainParams.
        raise SystemExit(
            f"unknown dataset name {dataset_name!r}; expected 'O', 'M' or 'C'")

    from dataset import Dataset

    params_train = MetaTrainParams()
    params_test = MetaTestParams()

    import torchvision
    from torchvision.transforms import Compose, Resize, ToTensor

    # Same preprocessing as evaluate_splitmnist: resize to 84, no normalisation.
    mnist_transform = Compose([Resize(84), ToTensor()])
    dataset_train = torchvision.datasets.MNIST(
        download=True, root='./data', train=True, transform=mnist_transform)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    model = params_train.modelClass(params_train).to(device)
    print ("loading model from " , check_point_path)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.back_bone.load_state_dict(
        torch.load(check_point_path, map_location=device))

    model = get_model_with_ml(model, params_train, dataset_train)

    acc_mean, acc_std = evaluate_splitmnist(params_test, model, verbose=True)
    print (acc_mean, acc_std)
