import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from joblib import Parallel, delayed
from termcolor import colored
from tqdm import tqdm

from icpe.feat import Feature
from icpe.model import Transformer
from icpe.MRP.boyan import BoyanChain
from icpe.MRP.cartpole import CartPole
from icpe.prompt import ContextGenerator, pro, simulate
from icpe.utils import compute_msve, set_seed


class Validation:
    '''
    Fixed pool of validation MRPs with pre-generated contexts.

    The constructor samples `instances` environments once and stores, for each
    of them, the true value function, the stationary distribution, the feature
    matrix and one batch of contexts per requested context length. `run` then
    scores a transformer on this pool for a given context length.
    '''

    def __init__(self, s: int, d: int, b: int, ns: np.ndarray, gamma: float,
                 mrp: str, feat_mode: str, expected: bool, instances: int = 5) -> None:
        '''
        param s: number of bins per feature if CartPole, otherwise number of states
        param d: feature dimension
        param b: batch size
        param ns: context lengths
        param gamma: discount factor
        param mrp: MRP type ('boyan' or 'cartpole')
        param feat_mode: feature mode
        param expected: forwarded unchanged as the last argument of
                        ContextGenerator.generate_context
        param instances: number of validation instances
        raises NotImplementedError: if an instance has to be built for an
                                    unknown MRP type
        '''
        self.instances = instances

        # CartPole is discretized with s bins on each of its 4 state
        # variables, so it has s ** 4 states in total.
        n_states = s ** 4 if mrp == 'cartpole' else s

        self.true_values = np.zeros((instances, n_states, 1))
        self.steady_dists = np.zeros((instances, n_states))
        self.phis = np.zeros((instances, n_states, d))
        # one (instances, b, 2*d+1, n) buffer per context length n
        self.ctxts = {n: np.zeros((instances, b, 2 * d + 1, n)) for n in ns}

        print('generating validation instances...')
        for i in tqdm(range(instances)):
            if mrp == 'boyan':
                # Feature is constructed before the weight is drawn; keep this
                # order so the random stream is consumed the same way.
                feat = Feature(d, n_states, feat_mode)
                weight = np.random.randn(d, 1)
                env = BoyanChain(s, gamma, weight, feat.phi)
            elif mrp == 'cartpole':
                feat = Feature(d, n_states, feat_mode)
                env = CartPole(s, gamma)
            else:
                raise NotImplementedError

            ctxt_generator = ContextGenerator(env, gamma, feat)
            self.true_values[i] = env.get_value()
            self.steady_dists[i] = env.get_steady_d()
            self.phis[i] = feat.phi
            # generate a batch of contexts for every context length
            # (third argument is always False here; train() passes
            # cfg['enumerated_context'] in that position)
            for n in ns:
                self.ctxts[n][i] = ctxt_generator.generate_context(
                    n, b, False, expected)

    def run(self, tf: Transformer, n: int) -> float:
        '''
        param tf: Transformer
        param n: context length, must be one of the lengths given to __init__
        return: MSVE averaged over the context batch of every validation
                instance, then averaged over the instances
        '''
        # Work on a copy so that resetting the context length and switching to
        # eval mode happen on the copy.
        # NOTE(review): assumes tf.copy() returns an independent model - confirm.
        tf = tf.copy()
        tf.reset_ctxt_len(n)
        tf.eval()

        msves = []
        with torch.no_grad():
            for i in range(self.instances):
                # (b, 2*d+1, n)
                ctxt_batch = torch.from_numpy(self.ctxts[n][i])
                phi = torch.from_numpy(self.phis[i])  # (s, d)
                v_preds = Parallel(n_jobs=-1)(delayed(tf.fit_value_func)(c, phi)
                                              for c in ctxt_batch)
                # torch.no_grad() is thread-local and is not carried into the
                # joblib workers, so the returned tensors may still require
                # grad; detach() keeps the numpy conversion safe either way.
                per_ctxt = [compute_msve(v.detach().cpu().numpy(),
                                         self.true_values[i],
                                         self.steady_dists[i])
                            for v in v_preds]
                msves.append(np.mean(per_ctxt))
        return np.mean(msves)


def _build_prompts(ctxts, feat, indices) -> torch.Tensor:
    '''
    Pair every context of a batch with the feature of its query state.

    param ctxts: batch of contexts, iterated along its first axis
    param feat: feature function, called as feat(index)
    param indices: one feature index per context
    return: float32 tensor stacking pro(context, feat(index)) along a new
            leading axis
    '''
    prompts = [torch.from_numpy(pro(c, feat(i)))
               for i, c in zip(indices, ctxts)]
    return torch.stack(prompts, dim=0).float()


def train(seed: int, name: str, cfg: dict):
    '''
    Train a Transformer on freshly sampled MRPs and periodically validate it.

    Every step samples a new feature function and environment, generates a
    batch of contexts, builds the regression target selected by cfg['target']
    ('TD', 'MC' or 'V') and takes one Adam step on the MSE loss. Every
    cfg['log_interval'] steps the parameters and the validation MSVEs are
    written to ./log/<name>/<seed>/params and ./log/<name>/<seed>/validation.

    param seed: random seed
    param name: experiment name
    param cfg: configuration dictionary; cfg['feature_dim'] and
               cfg['context_len'] may be overwritten in place (see the
               warnings below)
    raises NotImplementedError: if cfg['mrp'] or cfg['target'] is not supported
    '''
    set_seed(seed)  # set the random seed

    # make log directories (creating the sub-directories also creates run_dir)
    run_dir = os.path.join('.', 'log', name, str(seed))
    params_dir = os.path.join(run_dir, 'params')
    val_dir = os.path.join(run_dir, 'validation')
    for path in (params_dir, val_dir):
        os.makedirs(path, exist_ok=True)

    # reconcile settings that are tied to the number of states
    if cfg['feature_mode'] == 'one-hot' and cfg['feature_dim'] != cfg['num_states']:
        print(
            colored(
                f'Warning: Feature dimension must equal to the number of states in one-hot encoding. Force resetting feature dimension({cfg["feature_dim"]} -> {cfg["num_states"]})\n',
                'red')
        )
        cfg['feature_dim'] = cfg['num_states']
    if cfg['enumerated_context'] and cfg['context_len'] != cfg['num_states']:
        print(
            colored(
                f'Warning: Context length must equal to the number of states for enumerated context. Force resetting context length({cfg["context_len"]} -> {cfg["num_states"]})\n',
                'red')
        )
        cfg['context_len'] = cfg['num_states']

    # Fail fast: without this an unsupported target is only noticed inside the
    # first training step, after all validation instances have been generated.
    if cfg['mrp'] not in ('cartpole', 'boyan'):
        raise NotImplementedError(f"unsupported mrp: {cfg['mrp']!r}")
    if cfg['target'] not in ('TD', 'MC', 'V'):
        raise NotImplementedError(f"unsupported target: {cfg['target']!r}")

    tf = Transformer(d=cfg['feature_dim'], n=cfg['context_len'],
                     l=cfg['layer_num'], activation=cfg['activation'],
                     mode=cfg['tf_mode'], constrained=cfg['constrained'])
    optimizer = optim.Adam(tf.parameters(), lr=cfg['lr'],
                           weight_decay=cfg['weight_decay'])

    criterion = nn.MSELoss()

    # make validation instances
    val_ns = np.arange(cfg['val_ctxt_start'],
                       cfg['val_ctxt_end'] + 1,  # +1 makes the end inclusive
                       cfg['val_ctxt_step'])

    validator = Validation(cfg['num_cartpole_bins'] if cfg['mrp'] == 'cartpole' else cfg['num_states'],
                           cfg['feature_dim'],
                           cfg['batch_size'],
                           val_ns,
                           cfg['gamma'],
                           cfg['mrp'],
                           cfg['feature_mode'],
                           cfg['expected_feature'],
                           cfg['val_instances'])

    for task in tqdm(range(1, cfg['steps'] + 1)):
        if cfg['mrp'] == 'cartpole':
            # reinitialize the feature function
            feat = Feature(cfg['feature_dim'],
                           cfg['num_cartpole_bins'] ** 4,
                           cfg['feature_mode'])
            # CartPole's value function is never representable
            env = CartPole(cfg['num_cartpole_bins'], cfg['gamma'])
        elif cfg['mrp'] == 'boyan':
            # reinitialize the feature function
            feat = Feature(cfg['feature_dim'],
                           cfg['num_states'],
                           cfg['feature_mode'])
            # reinitialize the MRP
            if cfg['representable']:
                # pass a random weight together with the features to the chain
                weight = np.random.randn(cfg['feature_dim'], 1)
                env = BoyanChain(n_states=cfg['num_states'], gamma=cfg['gamma'],
                                 weight=weight, phi=feat.phi)
            else:
                env = BoyanChain(
                    n_states=cfg['num_states'], gamma=cfg['gamma'])
        else:
            raise NotImplementedError

        ctxt_generator = ContextGenerator(env, cfg['gamma'], feat)
        ctxts = ctxt_generator.generate_context(cfg['context_len'],
                                                cfg['batch_size'],
                                                cfg['enumerated_context'],
                                                cfg['expected_feature'])

        # Every target needs a batch of sampled query states and their
        # prompts, so simulate once here instead of once per branch.
        states, next_states, rewards = simulate(env, cfg['batch_size'])
        state_indices = env.get_feature_index(states)
        Z = _build_prompts(ctxts, feat, state_indices)

        if cfg['target'] == 'TD':
            # bootstrap from the model's own prediction at the next state
            next_state_indices = env.get_feature_index(next_states)
            Zp = _build_prompts(ctxts, feat, next_state_indices)
            rewards = torch.from_numpy(rewards).reshape(-1, 1).float()
            # the target is a constant w.r.t. the parameters (semi-gradient),
            # so skip building an autograd graph for it
            with torch.no_grad():
                target = rewards + cfg['gamma'] * tf.pred_v(Zp)
        elif cfg['target'] == 'MC':
            # call env.mc once per sampled state, in parallel
            target = Parallel(n_jobs=-1)(delayed(env.mc)(s, cfg['MC_steps'])
                                         for s in states)
            target = torch.FloatTensor(target).reshape(-1, 1)
        elif cfg['target'] == 'V':
            # regress directly onto the true values of the sampled states
            target = torch.FloatTensor(
                env.get_value()[np.array(state_indices, dtype=np.int32)]).reshape(-1, 1)
        else:
            raise NotImplementedError

        pred = tf.pred_v(Z)
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if task % cfg['log_interval'] == 0:
            Ps, Qs = tf.to_numpy()

            msves = np.array([validator.run(tf, n) for n in val_ns])

            # save the parameters and validation results
            np.save(os.path.join(params_dir, f'params_{task}.npy'), (Ps, Qs))
            np.save(os.path.join(val_dir, f'val_{task}.npy'), (val_ns, msves))
