import gym
import torch.optim as optim
from datetime import datetime
from tqdm import tqdm
import random
import time
import utils
from loss_functions import loss_fn_ed
import numpy as np
import torch
import torch.nn.functional as F

class TransformationCodingWrapper(gym.Wrapper):
    """Gym wrapper that trains an encoder with ``loss_fn_ed`` while the env runs.

    Each :meth:`step` call forwards the action to the wrapped environment and
    then does one of three things:

    * during the first ``args.warmup_steps`` calls it only advances the
      warmup counter;
    * afterwards, when ``random.random() <= args.update_prob`` and
      ``epoch_cnt`` has not reached ``args.num_epochs``, it performs one
      optimizer step on samples drawn from ``info['buffers']``;
    * otherwise it performs no update.

    In every case the returned ``info`` carries ``'epoch_cnt'`` and
    ``'avg_loss_list'``.  An epoch is ``args.steps_per_epoch`` optimizer
    steps; the LR scheduler is stepped once per finished epoch.
    """

    def __init__(self, env, args, enc, verbose=True):
        """Store the encoder and build its Adam optimizer and StepLR scheduler.

        Args:
            env: Environment to wrap; its ``step`` must return an ``info``
                dict containing ``'buffers'`` once updates start.
            args: Namespace-like object. Attributes read by this class:
                ``learning_rate``, ``weight_decay``, ``warmup_steps``,
                ``update_prob``, ``num_epochs``, ``steps_per_epoch``,
                ``num_actions_train``, ``batch_size``, ``barrier_type``,
                ``hinge_thresh``, ``cosine_sim``, ``conformal_map``,
                ``barrier_coef``.
            enc: Encoder whose parameters are optimized.
            verbose: When true, print a summary line after every epoch.
        """
        gym.Wrapper.__init__(self, env)
        self.args = args
        self.enc = enc
        self.verbose = verbose

        self.opt = optim.Adam(
            enc.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.scheduler = optim.lr_scheduler.StepLR(
            self.opt, step_size=9999, gamma=0.5
        )
        # TODO: step_size and gamma should come from the arguments

        self.stats = {
            'avg_loss_list': [],
            'avg_loss_equiv_list': [],
            'avg_loss_barrier_list': []
        }
        self.local_step = 0    # optimizer steps taken in the current epoch
        self.global_step = 0   # optimizer steps taken overall
        self.epoch_cnt = 0     # completed epochs
        self.warmup_cnt = 0    # env steps consumed by warmup so far

    def reset(self, **kwargs):
        """Reset the wrapped environment; training state is left untouched."""
        return self.env.reset(**kwargs)

    def step(self, a):
        """Step the wrapped env and, when due, run one encoder update.

        Args:
            a: Action forwarded unchanged to the wrapped environment.

        Returns:
            The wrapped env's ``(s_, reward, done, info)`` with
            ``info['epoch_cnt']`` and ``info['avg_loss_list']`` added.
        """
        s_, reward, done, info = self.env.step(a)

        if self.warmup_cnt == 0 and self.args.warmup_steps > 0:
            print('[%s] Started warmup' % datetime.now())
            self.progress = tqdm(range(self.args.warmup_steps), desc='Warmup')

        if self.warmup_cnt < self.args.warmup_steps:
            self.warmup_cnt += 1
            self.progress.update(1)
            if self.warmup_cnt == self.args.warmup_steps:
                self.progress.close()
                print('[%s] Finished warmup' % datetime.now())
            return s_, reward, done, self._attach_stats(info)

        # The random draw comes first so it is consumed on every post-warmup
        # step, including after the final epoch.
        if not (random.random() <= self.args.update_prob) or self.epoch_cnt == self.args.num_epochs:
            return s_, reward, done, self._attach_stats(info)

        if self.local_step == 0:
            self._start_epoch()

        self._update(self.enc, info['buffers'])
        self.progress.update(1)
        self.progress.set_description(
            'Loss: %12g | Loss Equiv: %12g | Loss Barrier: %12g | '
            'L2 Weights: %12g | L2 Grads: %12g' % (
                self.stats['loss_list'][-1],
                self.stats['loss_equiv_list'][-1],
                self.stats['loss_barrier_list'][-1],
                utils.get_weights_norm(self.enc.parameters(), norm_type=2.0),
                utils.get_grads_norm(self.enc.parameters(), norm_type=2.0)
            ))

        if self.local_step == self.args.steps_per_epoch:
            self._finish_epoch()

        return s_, reward, done, self._attach_stats(info)

    def _attach_stats(self, info):
        """Add the epoch counter and per-epoch loss history to ``info``."""
        info['epoch_cnt'] = self.epoch_cnt
        info['avg_loss_list'] = self.stats['avg_loss_list']
        return info

    def _start_epoch(self):
        """Reset the per-epoch loss lists, the timer and the progress bar."""
        self.stats['loss_list'] = []
        self.stats['loss_equiv_list'] = []
        self.stats['loss_barrier_list'] = []
        self.time_start = time.time()
        self.progress = tqdm(
            range(self.args.steps_per_epoch),
            desc='Loss: None | Loss Equiv: None | '
                 'Loss Barrier: None | L2 Weights: %12g | '
                 'L2 Grads: 0' % (
                     utils.get_weights_norm(self.enc.parameters(), norm_type=2.0)
                 ),
            total=self.args.steps_per_epoch, position=0, leave=True
        )

    def _finish_epoch(self):
        """Close the epoch: record mean losses, report, step the scheduler."""
        self.epoch_cnt += 1
        self.progress.close()
        self.local_step = 0
        time_end = time.time()

        self.stats['avg_loss'] = np.mean(self.stats['loss_list'])
        self.stats['avg_loss_equiv'] = np.mean(self.stats['loss_equiv_list'])
        self.stats['avg_loss_barrier'] = np.mean(self.stats['loss_barrier_list'])
        self.stats['avg_loss_list'].append(self.stats['avg_loss'])
        self.stats['avg_loss_equiv_list'].append(self.stats['avg_loss_equiv'])
        self.stats['avg_loss_barrier_list'].append(self.stats['avg_loss_barrier'])
        if self.verbose:
            print('\nEpoch %3d | Loss: %12g | Loss Equiv: %12g | '
                  'Loss Barrier: %12g | Time: %6.1f sec' % (
                      self.epoch_cnt, self.stats['avg_loss'],
                      self.stats['avg_loss_equiv'], self.stats['avg_loss_barrier'],
                      time_end - self.time_start)
                  )

        # StepLR takes no metric; it is stepped once per epoch.
        self.scheduler.step()

    def _update(self, enc, buffers):
        """Run one optimizer step on batches sampled from the action buffers.

        Args:
            enc: Encoder passed to ``loss_fn_ed`` (``step`` passes
                ``self.enc``).
            buffers: Indexable by action id; each entry must support
                ``.tolist()`` and hold at least ``args.batch_size`` samples.
                Samples are unstacked along axis 1 and scaled by 1/255
                (presumably stacked uint8 frames; verify against the
                buffer producer).
        """
        self.enc.train()
        self.opt.zero_grad()

        num_actions = self.env.action_space.n
        if self.args.num_actions_train < num_actions:
            action_list = random.sample(range(num_actions),
                                        k=self.args.num_actions_train)
        else:
            action_list = range(num_actions)

        loss_equiv = 0
        loss_barrier = 0
        for a in action_list:
            batch = np.array(random.sample(buffers[a].tolist(), self.args.batch_size))
            x_list = utils.unstack(batch, axis=1)
            x_list = [x.astype(np.float32)/255.0 for x in x_list]
            loss_equiv_, loss_barrier_ = loss_fn_ed(
                enc, x_list, self.args.barrier_type, self.args.hinge_thresh,
                self.args.cosine_sim, self.args.conformal_map
            )
            loss_equiv += loss_equiv_
            loss_barrier += loss_barrier_

        # Average over the actions actually used in this step. Dividing by
        # the full action count would scale the loss by k/n whenever only a
        # subset of actions is sampled.
        num_used = len(action_list)
        loss_barrier /= num_used
        loss_equiv /= num_used

        loss = loss_equiv + self.args.barrier_coef * loss_barrier
        loss.backward()
        self.opt.step()
        self.stats['loss_list'].append(loss.item())
        self.stats['loss_equiv_list'].append(loss_equiv.item())
        self.stats['loss_barrier_list'].append(self.args.barrier_coef * loss_barrier.item())
        self.local_step += 1
        self.global_step += 1

class AutoencoderWrapper(gym.Wrapper):
    """Gym wrapper that trains an encoder/decoder pair on reconstruction MSE.

    Each :meth:`step` call forwards the action to the wrapped environment.
    The first ``args.warmup_steps`` calls only advance the warmup counter.
    After that, when ``random.random() <= args.update_prob`` and
    ``epoch_cnt`` has not reached ``args.num_epochs``, one optimizer step is
    taken on a batch sampled from ``info['buffers']``.  The returned
    ``info`` always carries ``'epoch_cnt'`` and ``'avg_loss_list'``.
    """

    def __init__(self, env, args, enc, dec, verbose=True):
        """Store the models and build a joint Adam optimizer and scheduler.

        Args:
            env: Environment to wrap; its ``info`` must contain ``'buffers'``
                once updates start.
            args: Namespace-like object. Attributes read by this class:
                ``learning_rate``, ``weight_decay``, ``warmup_steps``,
                ``update_prob``, ``num_epochs``, ``steps_per_epoch``,
                ``batch_size``.
            enc: Encoder; optimized jointly with ``dec``.
            dec: Decoder applied to the encoder output.
            verbose: When true, print a summary line after every epoch.
        """
        gym.Wrapper.__init__(self, env)
        self.args = args
        self.enc = enc
        self.dec = dec
        self.verbose = verbose

        self.opt = optim.Adam(
            list(enc.parameters()) + list(dec.parameters()),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.scheduler = optim.lr_scheduler.StepLR(
            self.opt, step_size=9999, gamma=0.5
        )
        # TODO: step_size and gamma should come from the arguments

        self.stats = {
            'avg_loss_list': []
        }
        self.mse_loss = torch.nn.MSELoss(reduction='mean')
        self.local_step = 0    # optimizer steps taken in the current epoch
        self.global_step = 0   # optimizer steps taken overall
        self.epoch_cnt = 0     # completed epochs
        self.warmup_cnt = 0    # env steps consumed by warmup so far

    def reset(self, **kwargs):
        """Reset the wrapped environment; training state is left untouched."""
        return self.env.reset(**kwargs)

    def step(self, a):
        """Step the wrapped env and, when due, run one autoencoder update.

        Args:
            a: Action forwarded unchanged to the wrapped environment.

        Returns:
            The wrapped env's ``(s_, reward, done, info)`` with
            ``info['epoch_cnt']`` and ``info['avg_loss_list']`` added.
        """
        s_, reward, done, info = self.env.step(a)

        if self.warmup_cnt == 0 and self.args.warmup_steps > 0:
            print('[%s] Started warmup' % datetime.now())
            self.progress = tqdm(range(self.args.warmup_steps), desc='Warmup')

        if self.warmup_cnt < self.args.warmup_steps:
            self.warmup_cnt += 1
            self.progress.update(1)
            if self.warmup_cnt == self.args.warmup_steps:
                self.progress.close()
                print('[%s] Finished warmup' % datetime.now())
            return s_, reward, done, self._attach_stats(info)

        # The random draw comes first so it is consumed on every post-warmup
        # step, including after the final epoch.
        if not (random.random() <= self.args.update_prob) or self.epoch_cnt == self.args.num_epochs:
            return s_, reward, done, self._attach_stats(info)

        if self.local_step == 0:
            self._start_epoch()

        self._update(info['buffers'])
        self.progress.update(1)
        # Only the encoder's weight/gradient norms are displayed.
        self.progress.set_description(
            'Loss: %12g | L2 Weights: %12g | L2 Grads: %12g' % (
                self.stats['loss_list'][-1],
                utils.get_weights_norm(self.enc.parameters(), norm_type=2.0),
                utils.get_grads_norm(self.enc.parameters(), norm_type=2.0)
            ))

        if self.local_step == self.args.steps_per_epoch:
            self._finish_epoch()

        return s_, reward, done, self._attach_stats(info)

    def _attach_stats(self, info):
        """Add the epoch counter and per-epoch loss history to ``info``."""
        info['epoch_cnt'] = self.epoch_cnt
        info['avg_loss_list'] = self.stats['avg_loss_list']
        return info

    def _start_epoch(self):
        """Reset the per-epoch loss list, the timer and the progress bar."""
        self.stats['loss_list'] = []
        self.time_start = time.time()
        self.progress = tqdm(
            range(self.args.steps_per_epoch),
            desc='Loss: None | L2 Weights: %12g | '
                 'L2 Grads: 0' % (
                     utils.get_weights_norm(self.enc.parameters(), norm_type=2.0)
                 ),
            total=self.args.steps_per_epoch, position=0, leave=True
        )

    def _finish_epoch(self):
        """Close the epoch: record the mean loss, report, step the scheduler."""
        self.epoch_cnt += 1
        self.progress.close()
        self.local_step = 0
        time_end = time.time()

        self.stats['avg_loss'] = np.mean(self.stats['loss_list'])
        self.stats['avg_loss_list'].append(self.stats['avg_loss'])
        if self.verbose:
            print('\nEpoch %3d | Loss: %12g | Time: %6.1f sec' % (
                self.epoch_cnt, self.stats['avg_loss'],
                time_end - self.time_start)
                  )

        # StepLR takes no metric; it is stepped once per epoch.
        self.scheduler.step()

    def _update(self, buffers):
        """Run one reconstruction step on a batch sampled from ``buffers``.

        Args:
            buffers: Object supporting ``.tolist()`` with at least
                ``args.batch_size`` entries; values are scaled by 1/255
                before encoding (presumably uint8 images; verify against
                the buffer producer).
        """
        # Both networks are optimized, so both must be in training mode; the
        # decoder could otherwise stay in eval mode if it was switched
        # elsewhere.
        self.enc.train()
        self.dec.train()
        self.opt.zero_grad()
        batch = np.array(random.sample(buffers.tolist(), self.args.batch_size))
        device = utils.get_device(self.enc)
        batch = torch.Tensor(batch).to(device)
        batch = batch.float() / 255.0
        batch_rec = self.dec(self.enc(batch))
        loss = self.mse_loss(batch_rec, batch)
        loss.backward()
        self.opt.step()
        self.stats['loss_list'].append(loss.item())
        self.local_step += 1
        self.global_step += 1


class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over paired embeddings for a fixed batch size.

    The ``negatives_mask`` buffer is a ``(2 * batch_size, 2 * batch_size)``
    float matrix that is 0 on the diagonal and 1 elsewhere, so a sample's
    similarity with itself is excluded from the denominator.
    """

    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        total = 2 * batch_size
        off_diagonal = torch.eye(total, total, dtype=torch.bool).logical_not()
        self.register_buffer("temperature", torch.tensor(temperature))
        self.register_buffer("negatives_mask", off_diagonal.float())

    def forward(self, emb_i, emb_j, device):
        """
        Compute the loss for two batches of embeddings whose rows with the
        same index form the positive pairs (z_i, z_j in the SimCLR paper).

        ``device`` is where the mask is moved before it is applied; the
        result is a scalar averaged over all ``2 * batch_size`` rows.
        """
        n = self.batch_size
        stacked = torch.cat(
            [F.normalize(emb_i, dim=1), F.normalize(emb_j, dim=1)], dim=0
        )
        # Pairwise cosine similarity between every row of the stacked batch.
        sim = F.cosine_similarity(stacked.unsqueeze(1), stacked.unsqueeze(0), dim=2)

        # The +n / -n off-diagonals hold sim(i, j) and sim(j, i) of each pair.
        positives = torch.cat([torch.diag(sim, n), torch.diag(sim, -n)], dim=0)

        numer = torch.exp(positives / self.temperature)
        masked_exp = self.negatives_mask.to(device) * torch.exp(sim / self.temperature)
        row_totals = torch.sum(masked_exp, dim=1)

        per_row = -torch.log(numer / row_totals)
        return torch.sum(per_row) / (2 * n)


class SimCLRWrapper(gym.Wrapper):
    """Gym wrapper that trains an encoder with :class:`ContrastiveLoss`.

    Stepping the wrapper steps the wrapped environment.  The first
    ``args.warmup_steps`` calls are counted as warmup; later calls trigger
    one optimizer step when ``random.random() <= args.update_prob`` and
    ``epoch_cnt`` has not reached ``args.num_epochs``.  The returned
    ``info`` always carries ``'epoch_cnt'`` and ``'avg_loss_list'``.
    """

    def __init__(self, env, args, enc, verbose=True):
        gym.Wrapper.__init__(self, env)
        self.args = args
        self.enc = enc
        self.verbose = verbose

        self.opt = optim.Adam(
            list(enc.parameters()),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        # TODO: step_size and gamma should come from the arguments
        self.scheduler = optim.lr_scheduler.StepLR(
            self.opt, step_size=9999, gamma=0.5
        )

        self.stats = {'avg_loss_list': []}
        self.const_loss = ContrastiveLoss(args.batch_size, args.temp)

        # Counters: steps in the current epoch, steps overall, finished
        # epochs, and env steps consumed by warmup.
        self.local_step = 0
        self.global_step = 0
        self.epoch_cnt = 0
        self.warmup_cnt = 0

    def reset(self, **kwargs):
        """Delegate to the wrapped environment's ``reset``."""
        return self.env.reset(**kwargs)

    def step(self, a):
        """Step the env, run an encoder update when due, and tag ``info``."""
        transition = self.env.step(a)
        s_, reward, done, info = transition
        warmup_total = self.args.warmup_steps

        if self.warmup_cnt == 0 and warmup_total > 0:
            print('[%s] Started warmup' % datetime.now())
            self.progress = tqdm(range(self.args.warmup_steps), desc='Warmup')

        in_warmup = self.warmup_cnt < warmup_total
        if in_warmup:
            self.warmup_cnt += 1
            self.progress.update(1)
            if self.warmup_cnt == self.args.warmup_steps:
                self.progress.close()
                print('[%s] Finished warmup' % datetime.now())
            self._tag(info)
            return s_, reward, done, info

        # Draw first, then check the epoch budget (same order every call).
        update_drawn = random.random() <= self.args.update_prob
        if not update_drawn or self.epoch_cnt == self.args.num_epochs:
            self._tag(info)
            return s_, reward, done, info

        if self.local_step == 0:
            self._begin_epoch()

        self._update(info['buffers'])
        self.progress.update(1)
        latest_loss = self.stats['loss_list'][-1]
        weight_norm = utils.get_weights_norm(self.enc.parameters(), norm_type=2.0)
        grad_norm = utils.get_grads_norm(self.enc.parameters(), norm_type=2.0)
        self.progress.set_description(
            'Loss: %12g | L2 Weights: %12g | L2 Grads: %12g'
            % (latest_loss, weight_norm, grad_norm)
        )

        if self.local_step == self.args.steps_per_epoch:
            self._end_epoch()

        self._tag(info)
        return s_, reward, done, info

    def _tag(self, info):
        """Write the epoch counter and the loss history into ``info``."""
        info['epoch_cnt'] = self.epoch_cnt
        info['avg_loss_list'] = self.stats['avg_loss_list']

    def _begin_epoch(self):
        """Clear the per-epoch losses and open a fresh progress bar."""
        self.stats['loss_list'] = []
        self.time_start = time.time()
        initial_desc = (
            'Loss: None | L2 Weights: %12g | L2 Grads: 0'
            % utils.get_weights_norm(self.enc.parameters(), norm_type=2.0)
        )
        self.progress = tqdm(
            range(self.args.steps_per_epoch),
            desc=initial_desc,
            total=self.args.steps_per_epoch,
            position=0,
            leave=True,
        )

    def _end_epoch(self):
        """Record the epoch's mean loss, report it and step the scheduler."""
        self.epoch_cnt += 1
        self.progress.close()
        self.local_step = 0
        elapsed = time.time() - self.time_start

        epoch_mean = np.mean(self.stats['loss_list'])
        self.stats['avg_loss'] = epoch_mean
        self.stats['avg_loss_list'].append(epoch_mean)
        if self.verbose:
            print('\nEpoch %3d | Loss: %12g | Time: %6.1f sec'
                  % (self.epoch_cnt, epoch_mean, elapsed))

        self.scheduler.step()

    def _update(self, buffers):
        """Take one contrastive optimizer step on a sampled batch.

        Elements 0 and 2 of every sampled entry are encoded as the two
        views of a pair; both are scaled by 1/255 first.
        """
        self.enc.train()
        self.opt.zero_grad()

        device = utils.get_device(self.enc)
        batch_size = self.args.batch_size
        samples = random.sample(buffers.tolist(), batch_size)

        def gather(slot):
            # One (1, ...) float tensor per sample, concatenated on dim 0.
            frames = []
            for idx in range(batch_size):
                frame = torch.Tensor(np.array(samples[idx][slot])).float()
                frames.append(frame.unsqueeze(0).to(device) / 255.0)
            return torch.cat(frames, dim=0)

        batch = gather(0)
        next_batch = gather(2)

        x = self.enc(batch)
        x_ = self.enc(next_batch)
        loss = self.const_loss(x, x_, device)
        loss.backward()
        self.opt.step()

        self.stats['loss_list'].append(loss.item())
        self.local_step += 1
        self.global_step += 1


class TrainTransitionWrapper(gym.Wrapper):
    """Gym wrapper that trains a latent transition model on encoder outputs.

    Only ``trans`` is optimized.  ``enc`` is used to embed the sampled
    observations and to report weight/gradient norms.  Each :meth:`step`
    call forwards the action to the wrapped environment.  The first
    ``args.warmup_steps`` calls only advance the warmup counter; later
    calls run one optimizer step when
    ``random.random() <= args.update_prob`` and ``epoch_cnt`` has not
    reached ``args.num_epochs``.  The returned ``info`` always carries
    ``'epoch_cnt'`` and ``'avg_loss_list'``.
    """

    def __init__(self, env, args, enc, trans, verbose=True):
        """Store the models and build the optimizer for ``trans``.

        Args:
            env: Environment to wrap; ``env.action_space.n`` sizes the
                one-hot action vectors and its ``info`` must contain
                ``'buffers'`` once updates start.
            args: Namespace-like object. Attributes read by this class:
                ``learning_rate``, ``weight_decay``, ``warmup_steps``,
                ``update_prob``, ``num_epochs``, ``steps_per_epoch``,
                ``batch_size``.
            enc: Encoder applied to current and next observations.
            trans: Model exposing ``parameters()`` and
                ``transition(x, action_onehot)``.
            verbose: When true, print a summary line after every epoch.
        """
        gym.Wrapper.__init__(self, env)
        self.args = args
        self.enc = enc
        self.trans = trans
        self.verbose = verbose
        self.num_actions = env.action_space.n

        self.opt = optim.Adam(
            trans.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.scheduler = optim.lr_scheduler.StepLR(
            self.opt, step_size=9999, gamma=0.5
        )
        # TODO: step_size and gamma should come from the arguments

        self.stats = {
            'avg_loss_list': []
        }
        self.mse_loss = torch.nn.MSELoss(reduction='mean')
        self.local_step = 0    # optimizer steps taken in the current epoch
        self.global_step = 0   # optimizer steps taken overall
        self.epoch_cnt = 0     # completed epochs
        self.warmup_cnt = 0    # env steps consumed by warmup so far

    def reset(self, **kwargs):
        """Reset the wrapped environment; training state is left untouched."""
        return self.env.reset(**kwargs)

    def step(self, a):
        """Step the wrapped env and, when due, run one transition update.

        Args:
            a: Action forwarded unchanged to the wrapped environment.

        Returns:
            The wrapped env's ``(s_, reward, done, info)`` with
            ``info['epoch_cnt']`` and ``info['avg_loss_list']`` added.
        """
        s_, reward, done, info = self.env.step(a)

        if self.warmup_cnt == 0 and self.args.warmup_steps > 0:
            print('[%s] Started warmup' % datetime.now())
            self.progress = tqdm(range(self.args.warmup_steps), desc='Warmup')

        if self.warmup_cnt < self.args.warmup_steps:
            self.warmup_cnt += 1
            self.progress.update(1)
            if self.warmup_cnt == self.args.warmup_steps:
                self.progress.close()
                print('[%s] Finished warmup' % datetime.now())
            return s_, reward, done, self._attach_stats(info)

        # The random draw comes first so it is consumed on every post-warmup
        # step, including after the final epoch.
        if not (random.random() <= self.args.update_prob) or self.epoch_cnt == self.args.num_epochs:
            return s_, reward, done, self._attach_stats(info)

        if self.local_step == 0:
            self._start_epoch()

        self._update(info['buffers'])
        self.progress.update(1)
        # The displayed norms are the encoder's, not the transition model's.
        self.progress.set_description(
            'Loss: %12g | L2 Weights: %12g | L2 Grads: %12g' % (
                self.stats['loss_list'][-1],
                utils.get_weights_norm(self.enc.parameters(), norm_type=2.0),
                utils.get_grads_norm(self.enc.parameters(), norm_type=2.0)
            ))

        if self.local_step == self.args.steps_per_epoch:
            self._finish_epoch()

        return s_, reward, done, self._attach_stats(info)

    def _attach_stats(self, info):
        """Add the epoch counter and per-epoch loss history to ``info``."""
        info['epoch_cnt'] = self.epoch_cnt
        info['avg_loss_list'] = self.stats['avg_loss_list']
        return info

    def _start_epoch(self):
        """Reset the per-epoch loss list, the timer and the progress bar."""
        self.stats['loss_list'] = []
        self.time_start = time.time()
        self.progress = tqdm(
            range(self.args.steps_per_epoch),
            desc='Loss: None | L2 Weights: %12g | '
                 'L2 Grads: 0' % (
                     utils.get_weights_norm(self.enc.parameters(), norm_type=2.0)
                 ),
            total=self.args.steps_per_epoch, position=0, leave=True
        )

    def _finish_epoch(self):
        """Close the epoch: record the mean loss, report, step the scheduler."""
        self.epoch_cnt += 1
        self.progress.close()
        self.local_step = 0
        time_end = time.time()

        self.stats['avg_loss'] = np.mean(self.stats['loss_list'])
        self.stats['avg_loss_list'].append(self.stats['avg_loss'])
        if self.verbose:
            print('\nEpoch %3d | Loss: %12g | Time: %6.1f sec' % (
                self.epoch_cnt, self.stats['avg_loss'],
                time_end - self.time_start)
                  )

        # StepLR takes no metric; it is stepped once per epoch.
        self.scheduler.step()

    def _update(self, buffers):
        """Run one optimizer step fitting ``trans`` to encoded transitions.

        Args:
            buffers: Object supporting ``.tolist()`` with at least
                ``args.batch_size`` entries; each entry is indexed as
                ``[0]`` observation, ``[1]`` action index, ``[2]`` next
                observation by ``process_samples``.
        """
        # NOTE(review): the encoder is put in train mode and receives
        # gradients here, but the optimizer only holds ``trans`` parameters,
        # so ``zero_grad`` does not appear to clear the encoder's gradients
        # -- confirm this is intended.
        self.enc.train()
        self.opt.zero_grad()
        device = utils.get_device(self.enc)
        data = random.sample(buffers.tolist(), self.args.batch_size)
        batch, action_onehot, next_batch = process_samples(
            data, self.args.batch_size, self.num_actions, device
        )
        x, x_ = self.enc(batch), self.enc(next_batch)
        x_next = self.trans.transition(x, action_onehot)
        # The regression target is the embedding of the *next* observation;
        # comparing against ``x`` would train the model to reproduce its
        # own input embedding.
        loss = self.mse_loss(x_next, x_)
        loss.backward()
        self.opt.step()
        self.stats['loss_list'].append(loss.item())
        self.local_step += 1
        self.global_step += 1

def process_samples(data, batch_size, num_actions, device):
    """Turn sampled transitions into tensors on ``device``.

    Each ``data[i]`` is indexed as ``[0]`` observation, ``[1]`` action index
    and ``[2]`` next observation.  Only the first ``batch_size`` entries are
    used.

    Returns:
        ``(batch, action_onehot, next_batch)``: the observations and next
        observations as float tensors scaled by 1/255 and concatenated
        along dim 0, and a ``(batch_size, num_actions)`` one-hot action
        matrix.
    """
    def collect(slot):
        # One (1, ...) float tensor per sample, concatenated on dim 0.
        pieces = []
        for idx in range(batch_size):
            frame = torch.Tensor(np.array(data[idx][slot])).float()
            pieces.append(frame.unsqueeze(0).to(device) / 255.0)
        return torch.cat(pieces, dim=0)

    batch = collect(0)
    next_batch = collect(2)

    action_onehot = torch.zeros((batch_size, num_actions)).to(device)
    for row in range(batch_size):
        action_onehot[row, data[row][1]] = 1

    return batch, action_onehot, next_batch