import numpy as np
import torch
import torch.nn as nn

from collections.abc import Sequence
from seed import set_seed

from nn_utils import compute_scores

# Number of leading in-distribution samples loaded from the feature/score files
# (the loader below keeps only `[:TEST_SIZE_CLASSIFICATION]`).
TEST_SIZE_CLASSIFICATION = 10000
# Size of the in-distribution test split (20% of TEST_SIZE_CLASSIFICATION);
# the OOD test split is derived from it in proportion to `gamma`.
TEST_SIZE_OOD = int(0.2 * TEST_SIZE_CLASSIFICATION)

# For each in-distribution dataset: the OOD datasets that are concatenated when
# `config.ood_name` is one of the mixture names 'far_ood' / 'near_ood'.
ood_mixtures = dict(
    cifar10=dict(
        far_ood=['mnist', 'svhn', 'texture'],
        near_ood=['cifar100', 'tin', 'places365'],
    ),
    cifar100=dict(
        far_ood=['mnist', 'svhn', 'texture'],
        near_ood=['cifar10', 'tin', 'places365'],
    ),
)

###### Base Data Stream ######
class DataStream(Sequence):
    """Base class for a stream of OOD/ID samples and their scores.

    Subclasses fill the stream arrays `z_stream` (inputs of the score
    function), `s_stream` (their scores) and `y_stream` (labels), plus the
    reference arrays used by the metric helpers below: `z0`/`s0`, `z1`/`s1`,
    `z0_test`/`s0_test`, `z1_test`/`s1_test` and `z1_train`.

    Args:
        config: configuration object; `ood_name`, `id_name`, `gamma` and
            `num` are read here.
        seed: forwarded to `set_seed` before anything else.
    """

    def __init__(self, config, seed):
        # set seed
        set_seed(seed)

        # config
        self.config = config
        self.name0 = config.ood_name  # OOD dataset (or mixture) name
        self.name1 = config.id_name   # ID dataset name

        # `gamma` is the OOD fraction of the stream: num0 OOD and num1 ID samples
        self.gamma = config.gamma
        self.num0 = int(config.num * self.gamma)
        self.num1 = config.num - self.num0

        # entire data sequence
        self.z_stream = []
        self.s_stream = []
        self.y_stream = []

        # current position (-1 marks an exhausted stream, see `set_pos`)
        self.curr_pos = 0

    def __len__(self):
        return len(self.z_stream)

    def __getitem__(self, i):
        """Return `[z, s, y]` at stream position `i`."""
        return [self.z_stream[i], self.s_stream[i], self.y_stream[i]]

    def get_z1_train(self):
        """Return the ID training inputs (set by subclasses)."""
        return self.z1_train

    def inc_pos(self):
        """Advance the current stream position by one."""
        self.curr_pos += 1

    def set_pos(self, i):
        """Set the current stream position (-1 marks an exhausted stream)."""
        self.curr_pos = i

    def update_stream(self, g, device):
        """Re-score the stream and the reference sets with a new score function.

        Stream scores are recomputed only from the current position onwards,
        and not at all when the stream is exhausted (`curr_pos == -1`). The
        reference scores `s0`, `s1`, `s0_test` and `s1_test` are always
        recomputed. That way every stream reports scores of the same `g`.
        Otherwise callers that pool them across streams (e.g. the quantiles
        and extrema of `MergedStream`) would mix scores of different score
        functions.

        Args:
            g: new score function; must be an `nn.Module`.
            device: device forwarded to `compute_scores`.

        Raises:
            AssertionError: if `g` is not an `nn.Module` (checked before any
                state is modified).
        """
        assert isinstance(g, nn.Module), 'Error: new score function is invalid'
        # clear `self.g` (subclasses use it to name the pre-computed score files)
        self.g = None

        # replace the old scores after the current position with the new ones,
        # unless the stream is exhausted
        if self.curr_pos != -1:
            new_s_stream = compute_scores(g, self.z_stream[self.curr_pos:], device)
            self.s_stream[self.curr_pos:] = new_s_stream

        # update reference scores w.r.t. the new g
        self.s0 = compute_scores(g, self.z0, device)
        self.s1 = compute_scores(g, self.z1, device)
        self.s0_test = compute_scores(g, self.z0_test, device)
        self.s1_test = compute_scores(g, self.z1_test, device)

    def fpr(self, lam):
        """Fraction of OOD scores `s0` not below `lam` (false-positive rate at threshold `lam`)."""
        return 1-(self.s0 < lam).mean()

    def tpr(self, lam):
        """Fraction of ID scores `s1` not below `lam` (true-positive rate at threshold `lam`)."""
        return 1-(self.s1 < lam).mean()

    def tpr_test(self, alpha):
        """Test TPR at the threshold set to the (1 - alpha)-quantile of the OOD test scores."""
        lam = np.quantile(self.s0_test, 1-alpha)
        return 1-(self.s1_test < lam).mean()

    def min_score(self):
        """Smallest OOD reference score."""
        return np.min(self.s0)

    def ood_quantile_score(self, q):
        """Quantile(s) `q` of the OOD reference scores."""
        return np.quantile(self.s0, q, axis=0)

    def id_quantile_score(self, q):
        """Quantile(s) `q` of the ID reference scores."""
        return np.quantile(self.s1, q, axis=0)

    def max_score(self):
        """Largest ID reference score."""
        return np.max(self.s1)

###### Data Stream for class of two-layer neural networks ######
class DataStream1(DataStream):
    """Data stream for the class of two-layer neural networks.

    Loads pre-computed features (`../data/features/<model>/<name>.npz`) and
    pre-computed scores of the initial score function `config.init_g`
    (`../data/norm_scores/<model>/<init_g>/<name>.npz`). It does so for the
    OOD dataset (or mixture of datasets) and for the ID dataset, then holds
    out test splits. Finally it builds a shuffled stream of
    `int(num * gamma)` OOD samples (label 0) and `num - int(num * gamma)` ID
    samples (label 1), drawn with replacement.
    """

    def __init__(self, config, seed):
        super().__init__(config, seed)

        # create streams; `self.g` names the directory of pre-computed scores
        self.g = config.init_g
        self.create_stream()

    @staticmethod
    def _load_npz(path):
        """Return the `arr_0` array of the `.npz` archive at `path`.

        The archive is opened as a context manager so its file handle is
        closed right away instead of whenever the `NpzFile` is collected.
        """
        with np.load(path) as archive:
            return archive['arr_0']

    def _load_features(self, name):
        """Load the pre-computed features of dataset `name` for `config.model`."""
        return self._load_npz(f'../data/features/{self.config.model}/{name}.npz')

    def _load_scores(self, name):
        """Load the pre-computed scores of dataset `name` under score function `self.g`."""
        return self._load_npz(f'../data/norm_scores/{self.config.model}/{self.g}/{name}.npz')

    def create_stream(self):
        """Load OOD/ID data, hold out test splits and build the shuffled stream.

        Sets `z0`/`s0` and `z1`/`s1` (all loaded OOD / ID inputs and scores),
        `z0_test`/`s0_test`, `z1_train`/`z1_test`/`s1_test`, and the stream
        arrays `z_stream`, `s_stream`, `y_stream`. All randomness comes from
        NumPy's global RNG (presumably seeded by `set_seed` in
        `DataStream.__init__` — confirm).
        """
        # load OOD: either a named mixture of several datasets or a single dataset
        if self.name0 == 'far_ood' or self.name0 == 'near_ood':
            ood_names = ood_mixtures[self.name1][self.name0]
            self.z0 = np.concatenate([self._load_features(name) for name in ood_names], axis=0)
            self.s0 = np.concatenate([self._load_scores(name) for name in ood_names], axis=0)
        else:
            self.z0 = self._load_features(self.name0)
            self.s0 = self._load_scores(self.name0)
        assert len(self.z0) == len(self.s0), 'Error: out-of-distribution sample sizes do not match'

        # the ID test split has TEST_SIZE_OOD samples; scale the OOD test split so
        # that OOD:ID in the test data is gamma:(1 - gamma) (requires gamma < 1)
        num0_test = int(TEST_SIZE_OOD * self.gamma / (1 - self.gamma))
        test_indices0, indices0 = np.split(np.random.permutation(len(self.z0)), [num0_test])
        z0, self.z0_test = self.z0[indices0], self.z0[test_indices0]
        s0, self.s0_test = self.s0[indices0], self.s0[test_indices0]

        # load ID (first TEST_SIZE_CLASSIFICATION samples, i.e. those not used for
        # training the classifier model) and split it into train and test
        self.z1 = self._load_features(self.name1)[:TEST_SIZE_CLASSIFICATION]
        self.s1 = self._load_scores(self.name1)[:TEST_SIZE_CLASSIFICATION]
        assert len(self.z1) == len(self.s1), 'Error: in-distribution sample sizes do not match'

        test_indices1, indices1 = np.split(np.random.permutation(len(self.z1)), [TEST_SIZE_OOD])
        self.z1_train, self.z1_test = self.z1[indices1], self.z1[test_indices1]
        s1, self.s1_test = self.s1[indices1], self.s1[test_indices1]

        # sample OOD/ID from the non-test data with replacement
        ind0 = np.random.choice(len(z0), size=self.num0, replace=True)
        ind1 = np.random.choice(len(self.z1_train), size=self.num1, replace=True)
        z0, s0 = z0[ind0], s0[ind0]
        z1, s1 = self.z1_train[ind1], s1[ind1]
        y0, y1 = np.zeros(len(z0)), np.ones(len(z1))  # OOD -> 0, ID -> 1

        # combine OOD and ID into one sequence and shuffle it
        ind = np.random.permutation(len(z0) + len(z1))
        self.z_stream = np.concatenate((z0, z1), axis=0)[ind]
        self.s_stream = np.concatenate((s0, s1), axis=0)[ind]
        self.y_stream = np.concatenate((y0, y1), axis=0)[ind]

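###### Data stream for linear combinations of different scoring functions (under development) ######
# NOTE: this draft still uses the old `__init__(config)` signature; DataStream.__init__ now also requires `seed`.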
# class DataStream2(DataStream):
#     def __init__(self, config):
#         super().__init__(config)

#         # create streams
#         self.g = None
#         self.w = config.w
#         self.g_lst = config.g_lst
#         self.create_stream()

#     def create_stream(self):
#         # load OOD and split
#         self.x0 = np.load(f'../data/datasets/{self.name0}.npz')['arr_0']

#         self.z0 = []
#         for g in self.g_lst:
#             self.z0.append(np.load(f'../data/norm_scores/{self.config.model}/{g}/{self.name0}.npz')['arr_0'])
#         self.z0 = np.array(list(zip(*self.z0)))
#         self.s0 = np.array([np.dot(self.w, v0) for v0 in self.z0])
#         assert len(self.x0) == len(self.z0) == len(self.s0), 'Error: out-of-distribution sample sizes do not match'

#         test_indices0, indices0 = np.split(np.random.permutation(len(self.x0)), [int(TEST_SIZE_OOD * self.gamma / (1-self.gamma))])
#         x0, self.x0_test = self.x0[indices0], self.x0[test_indices0]
#         z0, self.z0_test = self.z0[indices0], self.z0[test_indices0]
#         s0, self.s0_test = self.s0[indices0], self.s0[test_indices0]
        
#         # load ID and split it into train and test (samples not used for training the classifier model)
#         self.x1 = np.load(f'../data/datasets/{self.name1}.npz')['arr_0'][:TEST_SIZE_CLASSIFICATION]

#         self.z1 = []
#         for g in self.g_lst:
#             self.z1.append(np.load(f'../data/norm_scores/{self.config.model}/{g}/{self.name1}.npz')['arr_0'][:TEST_SIZE_CLASSIFICATION])
#         self.z1 = np.array(list(zip(*self.z1)))
#         self.s1 = np.array([np.dot(self.w, v1) for v1 in self.z1])
#         assert len(self.x1) == len(self.z1) == len(self.s1), 'Error: in-distribution sample sizes do not match'
        
#         test_indices1, indices1 = np.split(np.random.permutation(len(self.x1)), [TEST_SIZE_OOD])
#         x1, self.x1_test = self.x1[indices1], self.x1[test_indices1]
#         self.z1_train, self.z1_test = self.z1[indices1], self.z1[test_indices1]
#         s1, self.s1_test = self.s1[indices1], self.s1[test_indices1]

#         # sample OOD/ID from data with replacement
#         ind0 = np.random.choice(len(x0), size=self.num0, replace=True)
#         ind1 = np.random.choice(len(x1), size=self.num1, replace=True)
#         x0, z0, s0 = x0[ind0], z0[ind0], s0[ind0]
#         x1, z1, s1 = x1[ind1], self.z1_train[ind1], s1[ind1]
#         y0, y1 = np.zeros(len(x0)), np.ones(len(x1))

#         # combine OOD and ID and create sequences (and shuffle)
#         ind = np.random.permutation(len(x0) + len(x1))
#         self.x_stream = np.concatenate((x0,x1), axis=0)[ind]
#         self.z_stream = np.concatenate((z0,z1), axis=0)[ind]
#         self.s_stream = np.concatenate((s0,s1), axis=0)[ind]
#         self.y_stream = np.concatenate((y0,y1), axis=0)[ind]

###### Data Stream for Distribution Shifts ######
class MergedStream(Sequence):
    """Concatenation of several data streams, used to model distribution shifts.

    Samples are served from `streams[0]`, then `streams[1]`, and so on.
    `fpr`, `tpr` and `tpr_test` refer to the stream currently being consumed
    (`self.id`). Score extrema and quantiles pool the reference scores of all
    streams.

    Args:
        streams: sequence of `DataStream`-like objects. They must support
            `len`, indexing, `set_pos`, `update_stream`, `get_z1_train`,
            `min_score`/`max_score`, and expose `s0`, `s1`, `s0_test` and
            `s1_test`.
    """

    def __init__(self, streams):
        # all streams
        self.streams = streams
        self.num_streams = len(streams)
        self.len_streams = [len(s) for s in streams]
        # built-in sum keeps this a Python int (np.sum([]) is the float 0.0,
        # which `len()` rejects)
        self.num = sum(self.len_streams)

        # current position in each stream (-1 once a stream is exhausted)
        self.loc_pos_lst = [0] * self.num_streams

        # index of the stream currently being consumed
        self.id = 0

        # current global position
        self.curr_global_pos = 0

    def __len__(self):
        return self.num

    def __getitem__(self, i):
        """Return the item at global index `i` (negative indices count from the end)."""
        stream_id, loc_pos = self.convert_to_loc_pos(i)
        return self.streams[stream_id][loc_pos]

    def get_z1_train(self):
        """Return the ID training inputs of the first stream."""
        return self.streams[0].get_z1_train()

    def convert_to_loc_pos(self, i):
        """Convert global index `i` into `(stream_id, local_position)`.

        Negative indices count from the end of the merged stream.

        Raises:
            IndexError: if `i` is outside `[-len(self), len(self))`.
        """
        if i < 0:
            i += self.num
        if i < 0:
            raise IndexError("Index out of range")
        loc_pos = i
        for stream_id, stream_len in enumerate(self.len_streams):
            if loc_pos < stream_len:
                return stream_id, loc_pos
            loc_pos -= stream_len
        raise IndexError("Index out of range")

    def inc_pos(self):
        """Advance the global position by one sample.

        Moves forward within the current stream. When that stream is
        exhausted, its local position is marked -1 and the next non-exhausted
        stream becomes current. The last stream is advanced like the others
        and stays current after it is exhausted, so `self.id` remains a valid
        stream index. Finally every stream's position is synced via `set_pos`.
        """
        last = self.num_streams - 1

        # skip streams that are already exhausted, but never move past the last one
        while self.id < last and self.loc_pos_lst[self.id] == -1:
            self.id += 1

        if self.num_streams > 0 and self.loc_pos_lst[self.id] != -1:
            self.loc_pos_lst[self.id] += 1
            self.curr_global_pos += 1

            # if current stream is exhausted, mark it as -1 and move to the next one
            if self.loc_pos_lst[self.id] >= self.len_streams[self.id]:
                self.loc_pos_lst[self.id] = -1
                while self.id < last and self.loc_pos_lst[self.id] == -1:
                    self.id += 1

        # set the local coordinates in all streams
        for stream, loc_pos in zip(self.streams, self.loc_pos_lst):
            stream.set_pos(loc_pos)

    def update_stream(self, g, device):
        """Forward the new score function `g` to every stream, exhausted ones included."""
        # g can take either form (Method 1 or Method 2)
        for stream in self.streams:
            stream.update_stream(g, device)

    def fpr(self, lam):
        """Fraction of the current stream's OOD scores not below `lam`."""
        return 1-(self.streams[self.id].s0 < lam).mean()

    def tpr(self, lam):
        """Fraction of the current stream's ID scores not below `lam`."""
        return 1-(self.streams[self.id].s1 < lam).mean()

    def tpr_test(self, alpha):
        """Current stream's test TPR at the (1 - alpha)-quantile of its OOD test scores."""
        lam = np.quantile(self.streams[self.id].s0_test, 1-alpha)
        return 1-(self.streams[self.id].s1_test < lam).mean()

    def min_score(self):
        """Smallest `min_score()` over all streams."""
        return np.min([stream.min_score() for stream in self.streams])

    def ood_quantile_score(self, q):
        """Quantile(s) `q` of the OOD reference scores pooled over all streams."""
        return np.quantile(np.concatenate([stream.s0 for stream in self.streams]), q, axis=0)

    def id_quantile_score(self, q):
        """Quantile(s) `q` of the ID reference scores pooled over all streams."""
        return np.quantile(np.concatenate([stream.s1 for stream in self.streams]), q, axis=0)

    def max_score(self):
        """Largest `max_score()` over all streams."""
        return np.max([stream.max_score() for stream in self.streams])
        