import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


class DelayedDiscrimination:
    """Synthetic delayed-discrimination task producing (input, target) sequences.

    On each trial every input channel receives a first stimulus value ``t1``,
    then (after a delay) a second value ``t2``. The target is ``sign(t1 - t2)``,
    optionally followed by ``|t1 - t2|``, presented in a response window whose
    onset shifts with the delay.

    ``params`` keys read by this class:
        low_ts, high_ts: bounds for stimulus values, sampled with
            ``np.random.randint(low_ts, high_ts)`` (``high_ts`` is exclusive).
        max_delay: number of distinct delays; generated delays lie in
            ``[0, max_delay)`` (a single delay of 0 when ``max_delay`` is 0).
        dim_input, dim_output: channel counts of the input / target arrays.
            The per-channel target (length ``dim_input``, or ``2 * dim_input``
            when ``target == 'value'``) must broadcast to ``dim_output``.
        time_shuffled: if truthy, the second stimulus onset ignores the delay
            (the response window still moves with it).
        target: ``'value'`` appends ``|t1 - t2|`` to the sign target.
        num_samples, n_batch, device: dataset size, loader batch size and
            torch device used by ``get_train_loader``.
    """

    task_name = 'delayed_discrimination'

    # Trial time layout, in time steps.
    _T1_ONSET = 5          # first stimulus onset
    _T2_ONSET = 10         # second stimulus onset (before adding the delay)
    _RESPONSE_ONSET = 20   # target onset (before adding the delay)
    _BASE_STEPS = 40       # trial length without any delay

    def __init__(self, params):
        self.params = params
        self.low_t = params['low_ts']
        self.high_t = params['high_ts']
        self.max_delay = params['max_delay']

    @property
    def steps(self):
        """Trial length: the base length plus room for the largest delay."""
        return self._BASE_STEPS + self.max_delay

    @property
    def pulse(self):
        """Duration, in steps, of every stimulus and of the response window."""
        return 5

    @property
    def name(self):
        """Identifier combining the task name and trial length."""
        return f"{self.task_name}_{self.steps}"

    def _n_delays(self):
        # Number of distinct delays; max_delay == 0 means "always delay 0"
        # rather than an empty range (which would crash randint / division).
        return max(self.max_delay, 1)

    def generate_trial(self, t1, t2, delay=0, steps=None):
        """Build one trial.

        Args:
            t1, t2: per-channel stimulus values. At least ``dim_input`` entries
                each; only the first ``dim_input`` are written to the input.
            delay: extra steps before the second stimulus and the response.
                ``None`` draws one uniformly from ``[0, max_delay)``.
            steps: trial length; defaults to ``self.steps``. Windows that
                extend past ``steps`` are truncated by numpy slicing.

        Returns:
            ``(x, y)`` float arrays of shape ``(steps, dim_input)`` and
            ``(steps, dim_output)``; all zeros outside the stimulus and
            response windows.
        """
        if steps is None:
            steps = self.steps
        dim_in = self.params['dim_input']
        dim_out = self.params['dim_output']
        x = np.zeros((steps, dim_in))
        y = np.zeros((steps, dim_out))

        if delay is None:
            delay = np.random.randint(0, self._n_delays())

        pulse = self.pulse
        t2_onset = self._T2_ONSET
        if not self.params['time_shuffled']:
            t2_onset += delay
        # Each stimulus is a row vector broadcast over the pulse window.
        x[self._T1_ONSET:self._T1_ONSET + pulse, :] = (
            np.asarray(t1, dtype=float)[:dim_in])
        x[t2_onset:t2_onset + pulse, :] = np.asarray(t2, dtype=float)[:dim_in]

        diff = np.asarray(t1) - np.asarray(t2)
        target = np.sign(diff)
        if self.params['target'] == 'value':
            target = np.concatenate([target, np.abs(diff)])

        response_onset = self._RESPONSE_ONSET + delay
        y[response_onset:response_onset + pulse, :] = target
        return x, y

    def generate_t1_t2(self):
        """Draw ``dim_input`` stimulus pairs with ``t1 != t2`` in each pair.

        Raises:
            ValueError: if ``[low_ts, high_ts)`` holds fewer than two integers,
                in which case no pair with ``t1 != t2`` exists.
        """
        low, high = self.low_t, self.high_t
        if high - low < 2:
            # Without this guard the rejection loop below never terminates.
            raise ValueError(
                f"need at least two distinct stimulus values in "
                f"[{low}, {high}) to build t1 != t2 pairs")
        t1_list, t2_list = [], []
        for _ in range(self.params['dim_input']):
            t1 = np.random.randint(low, high)
            t2 = np.random.randint(low, high)
            while t1 == t2:  # rejection-sample so the comparison is never a tie
                t2 = np.random.randint(low, high)
            t1_list.append(t1)
            t2_list.append(t2)
        return t1_list, t2_list

    def generate_train_data(self):
        """Generate ``num_samples`` trials with delays spread evenly in order.

        Sample ``i`` gets delay ``i * max_delay // num_samples``. This keeps
        every delay in ``[0, max_delay)`` and balances the counts. When
        ``num_samples`` is a multiple of ``max_delay``, it gives exactly
        ``num_samples // max_delay`` samples per delay.

        Returns:
            ``(x, y)`` arrays of shape ``(num_samples, steps, dim_input)`` and
            ``(num_samples, steps, dim_output)``.
        """
        n_samples = self.params['num_samples']
        n_delays = self._n_delays()
        x = np.zeros((n_samples, self.steps, self.params['dim_input']))
        y = np.zeros((n_samples, self.steps, self.params['dim_output']))
        for i in range(n_samples):
            t1, t2 = self.generate_t1_t2()
            delay = i * n_delays // n_samples
            x[i], y[i] = self.generate_trial(t1, t2, delay=delay)
        return x, y

    def get_train_loader(self):
        """Wrap freshly generated training data in a shuffling ``DataLoader``.

        Tensors are cast to float32 before moving to ``params['device']``.
        """
        x_np, y_np = self.generate_train_data()
        device = self.params['device']
        # Cast before transfer so only float32 data is copied to the device.
        x = torch.from_numpy(x_np).float().to(device)
        y = torch.from_numpy(y_np).float().to(device)
        dataset = TensorDataset(x, y)
        return DataLoader(dataset,
                          batch_size=self.params['n_batch'],
                          shuffle=True)
