import os, sys
import math
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
from torch import nn
import torch
import torchcde
from physiopro.network.contiformer import AttrDict, EncoderLayer
import pickle
from typing import Dict
from utils import PICalibData, MeanStdevFilter
from utils import check_or_make_folder

# Module-wide compute device: the first CUDA GPU when one is visible to
# PyTorch, otherwise the CPU. Models and tensors below are moved onto it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

########################################################
####################  ContiFormer ######################
########################################################

class ContiFormer(nn.Module):
    """ContiFormer encoder wrapped with linear input/output projections.

    Irregularly sampled observations are embedded, combined with a sinusoidal
    encoding of their timestamps, linearly interpolated onto the query time
    grid and passed through the continuous-time encoder. The result is
    projected back to the observation dimension.

    Args:
        obs_dim: number of observed features per time step.
        device: torch device the sub-modules are moved to.
        params: hyper-parameter dictionary. Keys read here: ``atol``,
            ``rtol``, ``method``, ``num_layers``, ``n_head``, ``d_k``,
            ``num_nodes``, ``dropout`` and ``batch_size``.
    """

    def __init__(self, obs_dim, device, params: Dict):
        super(ContiFormer, self).__init__()
        args_ode = AttrDict({
            'use_ode': True, 'actfn': 'tanh', 'layer_type': 'concat', 'zero_init': True,
            'atol': params["atol"], 'rtol': params["rtol"], 'method': params["method"],
            'regularize': False, 'approximate_method': 'bilinear', 'nlinspace': 1,
            'linear_type': 'before', 'interpolate': 'linear', 'itol': 1e-2
        })
        num_layers = params['num_layers']
        n_head = params['n_head']
        d_k = d_v = params['d_k']
        # The encoder works on vectors of width d_model == n_head * d_k
        # (== n_head * d_v), so the embedding and the temporal encoding must
        # use that width too. The previous hard-coded 16 only matched the
        # encoder for configurations such as n_head=4, d_k=4.
        d_model = n_head * d_k
        d_inner = params['num_nodes']
        dropout = params["dropout"]

        # Sub-modules are created in the original order so parameter
        # initialisation and state_dict key order are unchanged.
        self.encoder = EncoderLayer(d_model=d_model, d_inner=d_inner, n_head=n_head,
                                    d_k=d_k, d_v=d_v, num_layers=num_layers, args=args_ode,
                                    dropout=dropout).to(device)
        self.lin_in = nn.Linear(obs_dim, d_model).to(device)   # [B, L, obs_dim] -> [B, L, d_model]
        self.lin_out = nn.Linear(d_model, obs_dim).to(device)  # [B, L, d_model] -> [B, L, obs_dim]

        # Wavelengths of the sinusoidal time encoding. Kept as a plain
        # attribute (not a buffer) so existing checkpoints still load.
        self.position_vec = torch.tensor(
            [math.pow(10000.0, 2.0 * (i // 2) / d_model) for i in range(d_model)])  # [d_model]
        self.batch_size = params["batch_size"]

    def temporal_enc(self, time):  # time: [B, L]
        """Return the sinusoidal encoding of ``time`` with shape [B, L, d_model].

        Even channels hold ``sin(t / w_i)`` and odd channels ``cos(t / w_i)``,
        where ``w_i`` are the entries of ``self.position_vec``.
        """
        result = time.unsqueeze(-1) / self.position_vec.to(time.device)  # [B, L, d_model]
        result[:, :, 0::2] = torch.sin(result[:, :, 0::2])
        result[:, :, 1::2] = torch.cos(result[:, :, 1::2])
        return result

    def pad_input(self, input, t0, tmax):
        """Append a copy of the last step at time ``tmax``.

        This lets the linear interpolation be evaluated up to ``tmax`` by
        holding the final embedding constant.

        Args:
            input: embedded sequence, [B, L, d_model].
            t0: observation times, [L].
            tmax: final time to pad to.

        Returns:
            ``(padded_input, padded_times)`` with shapes [B, L+1, d_model]
            and [L+1].
        """
        last_step = input[:, -1:, :]  # [B, 1, d_model]
        padded = torch.cat((input, last_step), dim=1)  # [B, L+1, d_model]
        padded_ts = torch.cat((t0, torch.tensor([tmax]).to(t0.device)), dim=0)  # [L+1]
        return padded, padded_ts

    def _encode(self, samples, orig_ts, tmax):
        """Run embedding, interpolation and the encoder for one batch.

        ``samples`` is [B, L, obs_dim+1] with the observation time stored in
        the last channel. ``orig_ts`` is the query time grid. Returns
        predictions of shape [B, len(orig_ts), obs_dim].
        """
        batch, n_query = samples.shape[0], len(orig_ts)

        obs_times = samples[..., -1]  # [B, L]
        hidden = self.lin_in(samples[..., :-1])  # [B, L, d_model]
        hidden = (hidden + self.temporal_enc(obs_times)).float()

        # All trajectories are assumed to share the time grid of the first.
        padded, padded_ts = self.pad_input(hidden, obs_times[0], tmax)

        path = torchcde.LinearInterpolation(padded, t=padded_ts)
        hidden = path.evaluate(orig_ts).float()  # [B, n_query, d_model]

        # as_tensor avoids the copy and warning torch.tensor() produces when
        # orig_ts is already a tensor.
        query_ts = torch.as_tensor(orig_ts).to(hidden.device)  # [n_query]
        query_ts = query_ts.unsqueeze(0).repeat(batch, 1).float()  # [B, n_query]

        # All-False mask: no query position is masked out.
        mask = torch.zeros(batch, n_query, 1).to(hidden.device).bool()

        out, _ = self.encoder(hidden, query_ts, mask=mask)  # [B, n_query, d_model]
        return self.lin_out(out)

    def forward(self, samples, orig_ts, **kwargs):  # samples: [B, L, obs_dim+1]
        """Predict the trajectories on the ``orig_ts`` grid.

        Keyword Args:
            tmax: required; final time used to pad the interpolation.
            is_train: when true, a random mini-batch of at most
                ``self.batch_size`` trajectories is drawn without replacement
                and only those are processed.

        Returns:
            ``(pred, sample_idx)``. ``pred`` is [B', len(orig_ts), obs_dim].
            ``sample_idx`` is the array of selected trajectory indices in
            training mode and ``None`` otherwise.
        """
        sample_idx = None
        if kwargs.get('is_train', False):
            available = samples.shape[0]
            # Cap the draw so a dataset smaller than the configured batch
            # size does not make choice(..., replace=False) raise.
            n_pick = min(self.batch_size, available)
            sample_idx = npr.choice(available, n_pick, replace=False)
            samples = samples[sample_idx, ...]

        return self._encode(samples, orig_ts, kwargs['tmax']), sample_idx

    def calculate_loss(self, out, target):
        """Return the summed squared error between prediction and target.

        Args:
            out: ``(pred_x, idx)`` as returned by :meth:`forward`.
            target: 3-tuple whose first element is the target tensor. When
                ``idx`` is not ``None`` the target is indexed with it first,
                so it lines up with the sampled mini-batch.
        """
        pred_x, idx = out
        target_x, _, _ = target
        if idx is not None:
            target_x = target_x[idx, ...]
        return ((pred_x - target_x) ** 2).sum()

class RunningAverageMeter(object):
    """Exponentially smoothed tracker for a stream of scalar values.

    ``val`` holds the most recent sample (``None`` until the first update)
    and ``avg`` the smoothed value. The first sample initialises ``avg``
    directly; later samples are blended in with weight ``1 - momentum``.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all samples seen so far."""
        self.avg = 0
        self.val = None

    def update(self, val):
        """Fold a new sample ``val`` into the running average."""
        is_first = self.val is None
        self.val = val
        if is_first:
            self.avg = val
            return
        keep = self.momentum
        self.avg = self.avg * keep + val * (1 - keep)

class Trainer:
    """Train and evaluate a ContiFormer model and export its predictions.

    Args:
        params: experiment configuration. Besides the model hyper-parameters
            it must provide ``input_filter`` (an object with ``invert`` used
            to undo normalisation) and the keys used to build the run folder
            name (see :meth:`_run_folder`).
    """

    # Plot colours (hex form of RGB (95, 206, 64), (234, 60, 51), (48, 111, 215)).
    _COLORS = {'g': '#5FCE40', 'r': '#EA3C33', 'b': '#306FD7'}

    def __init__(self, params: Dict):
        self.params = params
        self.input_filter: MeanStdevFilter = params['input_filter']

    # ------------------------------------------------------------------ helpers

    def _run_folder(self):
        """Return the output folder name encoding the run configuration."""
        p = self.params
        return (
            f"./seed{p['seed']}_{p['dataset_name']}_{p['niters']}niters_{p['lr']}lr_"
            f"({p['num_nodes']}nodes_{p['num_layers']})layers_"
            f"{p['n_head']}nhead_{p['d_k']}dk_"
            f"{p['train_horizon']}trHz_{p['val_horizon']}valHz_{p['interp_horizon']}intHz_"
            f"{p['batch_size']}bs_{p['delta_t']}deltaT_{p['nsample']}nsample"
        )

    def _sample_indices(self, horizon):
        """Draw a sorted random subset of ``nsample * horizon`` time indices."""
        picked = npr.choice(int(horizon * 1), int(self.params['nsample'] * horizon), replace=False)
        return sorted(picked.tolist())

    @staticmethod
    def _subsampled_input(trajs, orig_ts, idx):
        """Select time steps ``idx`` and append their timestamps as a channel.

        ``trajs`` is [N, T, D]; the result is a float tensor [N, len(idx), D+1]
        whose last channel holds ``orig_ts[idx]``.
        """
        obs = trajs[:, idx, :]
        ts = torch.as_tensor(orig_ts[idx]).to(obs.device)
        ts = ts.reshape(1, -1, 1).repeat(obs.shape[0], 1, 1)
        return torch.cat((obs, ts), dim=-1).float()

    @staticmethod
    def _mae_rmse(pred, target):
        """Return ``(mae, rmse)`` tensors, both summed over the feature axis."""
        diff = pred - target
        mae = torch.abs(diff).sum(dim=-1).mean()
        rmse = torch.sqrt((diff ** 2).sum(dim=-1).mean())
        return mae, rmse

    @staticmethod
    def _batch_sizes(total, batch):
        """Split ``total`` items into chunks of ``batch`` with no empty tail."""
        n_full, rest = divmod(total, batch)
        sizes = [batch] * n_full
        if rest:
            sizes.append(rest)
        return sizes

    def _error_columns(self):
        """Return ``(pred_cols, gr_cols)`` CSV column names for the dataset.

        Raises:
            ValueError: for an unknown ``ode_name`` or when fewer names than
                ``input_dim`` are available.
        """
        ode_name = self.params['ode_name']
        num_dims = self.params['input_dim']
        if ode_name == 'lorenz':
            stems = ['x', 'y', 'z']
        elif ode_name == 'hopper':
            stems = [f"S{dim}" for dim in range(11)]
        elif ode_name in ['walker2d', 'halfcheetah', 'pen_expert', 'hammer']:
            stems = [f"S{dim}" for dim in range(num_dims)]
        else:
            raise ValueError(f"Unsupported ode_name for error export: {ode_name!r}")
        if len(stems) < num_dims:
            raise ValueError(
                f"ode_name {ode_name!r} defines {len(stems)} columns but input_dim is {num_dims}")
        return [f"{s}_pred" for s in stems], [f"{s}_gr" for s in stems]

    def _save_visualization(self, pred_x, samp_trajs, test_target, folder, itr):
        """Plot the first validation trajectory against its prediction.

        The first half of the predicted sequence is drawn as interpolation
        and the remainder as extrapolation.
        """
        half = pred_x.shape[1] // 2
        xs_pos = pred_x[0][:half, :].cpu().numpy()
        xs_neg = pred_x[0][half - 1:, :].cpu().numpy()
        orig_traj = test_target[0].cpu().numpy()
        samp_traj = samp_trajs[0].cpu().numpy()

        color = self._COLORS
        fig = plt.figure()
        plt.plot(orig_traj[:, 0], orig_traj[:, 1],
                 color['g'], label='True Trajectory', linewidth=1.5)
        plt.plot(xs_pos[:, 0], xs_pos[:, 1], color['b'],
                 label='Interpolation', linewidth=1.5)
        plt.plot(xs_neg[:, 0], xs_neg[:, 1], color['r'],
                 label='Extrapolation', linewidth=1.5)
        plt.scatter(samp_traj[:, 0], samp_traj[:, 1], color=color['g'],
                    label='Sampled Data', s=10)
        plt.scatter(xs_pos[:, 0], xs_pos[:, 1], color=color['b'],
                    label='Prediction', s=10)
        plt.axis('off')
        plt.savefig(folder + f'/vis_{itr}iter.png', dpi=500)
        # Release the figure; otherwise one figure leaks per visualisation.
        plt.close(fig)
        print("Saved visualization figure")

    # ----------------------------------------------------------------- training

    def train(self, dataset: "PICalibData"):
        """Train the model, checkpointing the weights with the best val MAE.

        Writes ``data_dict.pkl`` (dataset and params), ``best_wghts.pth`` and,
        when ``params['visualize']`` is set, a figure every 50 iterations
        into the run folder.
        """
        p = self.params
        best_val = np.inf

        model = ContiFormer(p['input_dim'], device, p)
        optimizer = optim.Adam(model.parameters(), lr=p['lr'])
        loss_meter = RunningAverageMeter()

        ntotal = p['train_horizon']  # number of input time steps
        val_total = p['val_horizon']  # number of target time steps
        tmax = dataset.timesteps_val[-1]  # end time of the validation grid

        train_trajs = dataset.norm_train_X.to(device)
        train_target = dataset.norm_train_Y.to(device)
        test_trajs = dataset.norm_val_X.to(device)
        test_target = dataset.norm_val_Y.to(device)
        orig_ts = dataset.timesteps_val

        # The validation subsampling is fixed for the whole run.
        test_idx = self._sample_indices(ntotal)

        folder = self._run_folder()
        check_or_make_folder(folder)
        # Persist dataset and params alongside the checkpoints.
        with open(folder + "/data_dict.pkl", "wb") as f:
            pickle.dump({"dataset": dataset, "params": p}, f)

        for itr in range(0, p["niters"] + 1):
            # Training loop adapted from
            # https://github.com/microsoft/SeqML/blob/main/ContiFormer/spiral.py
            model.train()
            optimizer.zero_grad()

            # Fresh irregular subsampling of inputs and targets each step.
            idx = self._sample_indices(ntotal)
            out_idx = self._sample_indices(val_total)

            samp_trajs = self._subsampled_input(train_trajs, orig_ts, idx)
            pred, batch_idx = model(samp_trajs, orig_ts, idx=idx, is_train=True, tmax=tmax)
            loss = model.calculate_loss(
                (pred[:, out_idx, :], batch_idx),
                (train_target[:, out_idx, :], None, None))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())

            print('Iter: {}, running loss: {:.4f}'.format(itr, loss_meter.avg))

            # Validate with dropout disabled so the metric is deterministic.
            model.eval()
            with torch.no_grad():
                val_input = self._subsampled_input(test_trajs, orig_ts, test_idx)
                pred_x = model(val_input, orig_ts, idx=test_idx, tmax=tmax)[0]
                mae, rmse = self._mae_rmse(pred_x, test_target)
            print('Iter: {}, MAE: {:.4f}, RMSE: {:.4f}'.format(itr, mae.item(), rmse.item()))

            if mae.item() < best_val:
                best_val = mae.item()
                print(f"################### Best Model Found ###################")
                print(f"Saving Best Model at Iteration: {itr}")
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_idx': idx,
                    'test_idx': test_idx,
                    'itr': itr,
                }, folder + f"/best_wghts.pth")

            if p["visualize"] and itr % 50 == 0:
                # The model is in eval mode, so the validation prediction
                # above can be reused instead of a second forward pass.
                self._save_visualization(pred_x, val_input, test_target, folder, itr)

    # --------------------------------------------------------------- evaluation

    def test(self, dataset: "PICalibData"):
        """Evaluate the best checkpoint on the validation and training sets.

        Prints MAE/RMSE for both, stores validation predictions in
        ``pred.pkl`` and exports the un-normalised errors through
        :meth:`save_errors`.
        """
        p = self.params
        folder = self._run_folder()

        train_trajs = dataset.norm_train_X.to(device)
        train_target = dataset.norm_train_Y.to(device)
        test_trajs = dataset.norm_val_X.to(device)
        test_target = dataset.norm_val_Y.to(device)
        orig_ts = dataset.timesteps_val

        ntotal = p['train_horizon']
        # Honour an explicit params['tmax']; otherwise use the same value
        # train() uses so both phases pad the interpolation identically.
        tmax = p['tmax'] if 'tmax' in p else dataset.timesteps_val[-1]

        model = ContiFormer(p['input_dim'], device, params=p)

        # map_location lets a checkpoint written on GPU load on a CPU-only host.
        checkpoint = torch.load(folder + f"/best_wghts.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()  # disable dropout for evaluation
        test_idx = checkpoint['test_idx']  # same irregular sampling as in training

        ##################### Val Evaluation #####################
        with torch.no_grad():
            val_input = self._subsampled_input(test_trajs, orig_ts, test_idx)
            pred_x_val = model(val_input, orig_ts, idx=test_idx, tmax=tmax)[0]
            mae, rmse = self._mae_rmse(pred_x_val, test_target)
        print(f"MAE: {mae.item():.4f}, RMSE: {rmse.item():.4f}")

        torch.save({
            'pred': pred_x_val.detach().cpu().numpy(),
            'target': test_target.detach().cpu().numpy(),
        }, folder + "/pred.pkl")
        torch.cuda.empty_cache()

        ##################### Train Evaluation #####################
        # NOTE(review): hard-coded; possibly meant to be params['interp_horizon'] - confirm.
        interp_hz = 75
        # Prefer the train indices stored with the checkpoint so the
        # evaluation is reproducible; fall back to a fresh draw otherwise.
        idx = checkpoint.get('train_idx')
        if idx is None:
            idx = self._sample_indices(ntotal)

        all_preds = []
        start = 0
        # Evaluate in chunks to bound memory use.
        for batch_sz in self._batch_sizes(train_trajs.shape[0], 500):
            print(f"from {start} to {start+batch_sz}")
            with torch.no_grad():
                batch_input = self._subsampled_input(
                    train_trajs[start:start + batch_sz], orig_ts, idx)
                all_preds.append(model(batch_input, orig_ts, idx=idx, tmax=tmax)[0])
            start += batch_sz

        all_preds = torch.cat(all_preds, dim=0).to(device)
        mae, rmse = self._mae_rmse(all_preds, train_target)
        print(f"MAE: {mae.item():.4f}, RMSE: {rmse.item():.4f}")

        ##################### SAVING ERRORS #####################
        self.save_errors(all_preds[:, :interp_hz], train_target[:, :interp_hz], "train")
        self.save_errors(pred_x_val, test_target, "val")

    def moving_average(self, x, window=10):
        """
        Smooth ``x`` along the time axis with an edge-padded moving average.

        x: np.ndarray of shape (B, T, D)
        window: int, size of the moving average window
        returns: np.ndarray of shape (B, T, D)
        """
        assert window >= 1
        assert x.ndim == 3, "Input must be a 3D array of shape (B, T, D)"

        B, T, D = x.shape
        left_pad = window // 2
        right_pad = window - 1 - left_pad  # asymmetric if window is even

        # Integer input would truncate the averages; promote it to float.
        out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
        smoothed = np.empty(x.shape, dtype=out_dtype)
        kernel = np.ones(window) / window

        for b in range(B):
            for d in range(D):
                padded = np.pad(x[b, :, d], (left_pad, right_pad), mode='edge')
                smoothed[b, :, d] = np.convolve(padded, kernel, mode='valid')

        return smoothed

    def save_errors(self, pred_x: torch.Tensor, target: torch.Tensor, error_type: str):
        """Write un-normalised predictions and targets to a CSV file.

        One row is written per (trajectory, time step) with a ``horizon``
        column holding the step index, followed by the prediction columns and
        the ground-truth columns. An existing file is left untouched.

        Args:
            pred_x: predictions, [N, T, input_dim] (normalised).
            target: ground truth with the same shape (normalised).
            error_type: suffix of the output file name (e.g. ``"train"``).

        Raises:
            ValueError: if ``params['ode_name']`` has no column definition.
        """
        num_dims = self.params['input_dim']
        pred_cols, gr_cols = self._error_columns()
        folder = self._run_folder()

        # Undo the input normalisation on both tensors.
        preds_unnorm = self.input_filter.invert(pred_x.detach().cpu().numpy())
        targets_unnorm = self.input_filter.invert(target.detach().cpu().numpy())

        # Smoothing is currently disabled:
        # preds_unnorm = self.moving_average(preds_unnorm)

        no_of_trajs, val_hor = preds_unnorm.shape[0], preds_unnorm.shape[1]

        # Step index 0..val_hor-1 repeated for every trajectory.
        horizon = np.tile(np.arange(val_hor, dtype=np.float64), no_of_trajs)
        preds_flat = np.array(preds_unnorm.reshape(-1, num_dims))
        targets_flat = np.array(targets_unnorm.reshape(-1, num_dims))

        data_dict = {'horizon': horizon}
        for i in range(num_dims):
            data_dict[pred_cols[i]] = preds_flat[:, i]
        for i in range(num_dims):
            data_dict[gr_cols[i]] = targets_flat[:, i]

        df = pd.DataFrame(data_dict)

        out_path = folder + f"/{self.params['dataset_name']}_errors_conti_0Iter_{error_type}.csv"
        if os.path.exists(out_path):
            print("Error file already exists!")
        else:
            print('Error file Saved')
            df.to_csv(out_path, index=False)
