import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import numpy as np
import argparse
import datetime
import pandas as pd
from env import MatPred, CyclicGroup, CyclicHMM, Grid, Dihedral, CyclicRealHMM, LinearDynamicalSystem
import wandb
from models import RNNAgent, TFAgent, LSTMAgent, TFConfig, Agent
from tqdm import tqdm
from data import generate_data, load_data
from args import parse_args
from loguru import logger

# load the environment from args
def init_env(args) :
    match args.env :
        case 'CyclicHard' : 
            env = CyclicHMM(args, args.num_envs, args.state_dim, args.action_n, args.action_n, args.length, args.alpha)
        case 'CyclicEasy' : env = CyclicGroup(args, args.num_envs, args.state_dim, args.action_n, args.action_n, args.length)
        case 'HMM' : env = MatPred(args, args.num_envs, args.action_n, args.length, args.state_dim, args.action_n, mode = 'normal', rank = args.rank, perturb = args.perturb)
        case 'MatRot' : env = MatPred(args, args.num_envs, args.action_n, args.length, args.state_dim, args.action_n, mode = 'simplified')
        case 'Grid' : env = Grid(args, args.num_envs, args.state_dim, args.action_n, args.action_n, args.length)
        case 'Dihedral' : env = Dihedral(args, args.num_envs, args.state_dim, args.action_n, args.action_n, args.length)
        case 'CyclicRealHMM' : env = CyclicRealHMM(args, args.num_envs, args.state_dim, args.action_n, args.action_n, args.length, eps = args.eps)
        case 'LinearDynamicalSystem': env = LinearDynamicalSystem(args, args.num_envs, args.state_dim, args.action_n, args.action_n, args.length)
    return env

# load the agent from args
def init_agent(args) :
    args.output_size = args.state_dim if args.env in ('MatRot', 'HMM', 'CyclicEasy', 'Grid', 'Dihedral', 'CyclicRealHMM') else args.action_n
    if args.env == 'LinearRegression' : args.output_size = 1
    args.input_size = 4
    args.input_size += args.action_n if args.env not in ('LinearDynamicalSystem',) else args.state_dim
    # determine the input dimension of the network
    if args.block_size > 0 : 
        args.net_input_size = max(args.input_size, args.output_size)
    else :
        args.net_input_size = args.input_size
    match args.agent :
        case 'MLP' :
            assert args.block_size > 0
            agent = Agent(args.block_size * args.input_size, args.output_size).to(args.device)
        case 'RNN' : agent = RNNAgent(args.device, args.net_input_size, args.hidden_dim, args.num_layers, args.output_size).to(args.device)
        case 'LSTM' : agent = LSTMAgent(args.device, args.net_input_size, args.hidden_dim, args.num_layers, args.output_size).to(args.device)
        case 'TF' : 
            if args.tf_model == 'gpt2' : agent = TFAgent(args, args.net_input_size, args.output_size, TFConfig(args.hidden_dim, args.num_heads, args.dropout, args.mlp_layers, True, args.num_layers, args.nn_max_len, pos_embed_type = args.pos_embed)).to(args.device)
            else : agent = TFAgent(args, args.net_input_size, args.output_size, TFConfig(args.hidden_dim, args.num_heads, args.dropout, args.mlp_layers, True, args.num_layers, args.nn_max_len, pos_embed_type = args.pos_embed), mode = 'scratch').to(args.device)
    return agent

# evaluation process
def evaluate(args, environ, len, agent, filter) : 
    """Roll ``agent`` out on episodes from ``environ`` and score its predictions.

    ``environ.reset()`` is called first. At every position the agent sees the
    observations emitted so far, and its output (passed through ``filter``) is
    compared against a per-environment target. Runs under ``torch.no_grad()``.

    Args:
        args: run configuration (reads num_envs, input_size, net_input_size,
            output_size, nn_max_len, block_size, agent, env, action_n, real,
            device).
        environ: environment; this function calls reset(), step(), emit_obs(),
            emit_obs_distribution() and prediction_stage() on it.
        len: number of observations to emit; predictions are scored at the
            len + 1 positions 0..len. NOTE: shadows the builtin ``len`` here.
        agent: model under evaluation. For the MLP agent it is called on a
            flattened window of inputs; otherwise on a sequence prefix, and the
            last position of its output is used.
        filter: callable applied to the raw network output before scoring.
            NOTE: shadows the builtin ``filter`` here.

    Returns:
        (mean_loss, pred_loss), both numpy arrays of shape (len + 1,):
        ``mean_loss`` is the per-position loss averaged over the num_envs
        episodes; ``pred_loss`` is the per-position loss averaged over the
        episodes flagged by ``environ.prediction_stage()``, or -1 at positions
        where no episode was flagged.
    """
    s, _ = environ.reset()
    # per-position feature rows: the leading columns receive the encoded
    # observation as the rollout proceeds; the last four are auxiliary features
    input = torch.zeros(args.num_envs, len + 1, args.input_size).to(args.device)
    # column -1: constant 1
    input[:,:,-1] = torch.ones(args.num_envs, len + 1)
    # columns -2 / -3: cos / sin of the position, scaled so the angle stays
    # within [0, pi/4] for positions up to nn_max_len
    input[:,:,-2] = torch.cos(torch.tile(torch.arange(len + 1) * torch.pi * 0.25 / args.nn_max_len, (args.num_envs, 1)))
    input[:,:,-3] = torch.sin(torch.tile(torch.arange(len + 1) * torch.pi * 0.25 / args.nn_max_len, (args.num_envs, 1)))
    # column -4: set only at position 0 (presumably a start-of-sequence flag)
    input[:,0,-4] = torch.ones(args.num_envs)
    if args.block_size > 0 :
        # block CoT token sequence: position 0 stays an all-zero token, each
        # step appends one input token, and every block_size observations the
        # model's own output is appended as an extra token; sized to hold all
        # len + 1 input tokens plus the feedback tokens
        network_input = torch.zeros(args.num_envs, len + (len - 1) // args.block_size + 3, args.net_input_size).to(args.device)
        cur_step = 0
    cur_loss = np.zeros((args.num_envs, len + 1))
    # -1 marks positions where no episode is in the prediction stage
    pred_loss = np.zeros(len + 1) - 1
    # index of the last written token in network_input (block-CoT only);
    # redundantly reset here even when block_size == 0
    cur_step = 0
    with torch.no_grad() :
        for step in range(len + 1) :
            if args.block_size > 0 and args.agent == 'MLP' : # MLP agent
                # window of the last block_size input rows, left-aligned and
                # zero-padded at the end when fewer rows exist yet
                mlp_input = torch.zeros((args.num_envs, args.block_size * args.input_size)).to(args.device)
                feedback_length = min(step + 1, args.block_size)
                mlp_input[:, :feedback_length * args.input_size] = input[:, step - feedback_length + 1 : step + 1, :].reshape(args.num_envs, -1)
                output = agent(mlp_input)
            elif args.block_size > 0 : # block CoT enabled
                # append the current input row (zero-padded to the token width)
                # and predict from the last token
                cur_step += 1
                network_input[:,cur_step,:] = F.pad(input[:, step, :], (0, args.net_input_size - args.input_size), 'constant', 0)
                output = agent(network_input[:, :cur_step + 1, :])[:,-1,:]
            else :
                # plain sequence model: predict from the prefix of inputs
                output = agent(input[:, :step + 1, :])[:,-1,:]
            if args.env == 'MatRot' :
                # L2 distance between the environment state `s` and the prediction
                cur_loss[:,step] = np.linalg.norm(s - filter(output).cpu().numpy(), ord = 2, axis = 1)
            elif args.env == 'LinearDynamicalSystem' :
                # L2 error relative to the target's norm, with the norm clipped
                # below at 1 so small targets do not blow up the ratio
                expected_y = environ.emit_obs_distribution() # computed by Kalman filter
                cur_loss[:,step] = np.sqrt(np.sum((expected_y  - filter(output).cpu().numpy()) ** 2, axis = 1)) / np.clip(np.linalg.norm(expected_y, ord = 2, axis = 1), 1, None)
            else :
                # half the L1 distance to the target distribution (total
                # variation distance when both are probability vectors)
                cur_loss[:,step] = np.linalg.norm(environ.emit_obs_distribution() - filter(output).cpu().numpy(), ord = 1, axis = 1) / 2

            # compute the loss at prediction stage for Cyclic-HARD
            # pred_mask presumably is a per-episode boolean array; the dot
            # product averages the loss over flagged episodes only
            pred_mask = environ.prediction_stage()
            if (pred_mask == True).any() :
                pred_loss[step] = np.dot(pred_mask, cur_loss[:,step]) / np.sum(pred_mask)
            # the final position is scored but no further observation is emitted
            if step == len : break
            # next observation (named `action` here): one-hot encoded for
            # discrete environments, used as a float vector otherwise
            action = environ.emit_obs(real = args.real)
            if args.env not in ('LinearRegression', 'LinearDynamicalSystem') : 
                action_rep = F.one_hot(torch.LongTensor(action), args.action_n)
            else :
                action_rep = torch.FloatTensor(action)
            input[:, step + 1, :args.action_n] = torch.clone(action_rep)
            # at every block boundary feed the model's prediction back as a
            # token (also executed for the MLP agent, which never reads it)
            if args.block_size > 0 and (step + 1) % args.block_size == 0 :
                network_input[:,cur_step + 1,:] = F.pad(output, (0, args.net_input_size - args.output_size), 'constant', 0)
                cur_step += 1
            s, _ = environ.step(action)

    return np.mean(cur_loss, axis = 0), pred_loss

# setup the distributed data parallel
def ddp_setup(args, rank, world_size, backend='nccl', port = 12227):
    """Join the default torch.distributed process group and seed this rank.

    Rendezvous happens over ``localhost:port``. Each rank seeds numpy and torch
    with ``args.seed + 2398 * rank`` (presumably so ranks draw different data)
    and enables deterministic cuDNN kernels.
    """
    # rendezvous address must be in the environment before the group is created
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT=str(port))
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    rank_seed = args.seed + 2398 * rank
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    torch.backends.cudnn.deterministic = True

def run(rank, world_size, args, environ):
    """Train, checkpoint and periodically evaluate an agent on one rank.

    Per-process entry point: launched through ``mp.spawn`` (joining a DDP
    process group) when ``world_size > 1``, or called directly with
    ``rank == 0`` otherwise.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes.
        args: run configuration. Mutated in place: num_traj and batch_size are
            divided by world_size, device is overwritten, and (when set)
            directory gets "/state.pth" appended.
        environ: environment used for data generation and evaluation.

    Side effects: rank 0 may create ``args.directory`` and writes checkpoints
    there; progress goes to stdout, loguru and (optionally) wandb.
    """

    def report(msg) :
        # every progress message goes to both stdout and the loguru log
        print(msg)
        logger.info(msg)

    def save_checkpoint(epoch, cur_len) :
        # snapshot of model/optimizer/scheduler; resuming restarts at `epoch`
        state = {'epoch': epoch,
                 'cur_len': cur_len,
                 'model': agent_save.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict()}
        torch.save(state, args.directory)

    def apply_warmup(optim_step) :
        # Linear learning-rate warmup. `<=` makes the last warmup step set the
        # rate to exactly args.lr; with `<` the rate stayed at
        # args.lr * (warmup - 1) / warmup, because StepLR only rescales the
        # current rate and never restores it.
        if optim_step <= args.warmup :
            warmup_lr = args.lr * optim_step / args.warmup
            for param_group in optimizer.param_groups :
                param_group['lr'] = warmup_lr

    # set up the parallel environment
    if world_size > 1 : ddp_setup(args, rank, world_size, port = args.port)
    else : assert rank == 0
    args.num_traj = args.num_traj // world_size
    args.batch_size = args.batch_size // world_size
    args.device = rank
    # only validate the GPU index when CUDA exists, so the CPU fallback below
    # is reachable (device_count() is 0 without CUDA)
    if torch.cuda.is_available() :
        assert args.device < torch.cuda.device_count()
    report(f"Running rank {rank} on device {args.device}.")
    args.device = torch.device("cuda:" + str(args.device) if torch.cuda.is_available() else "cpu")

    # initialize the agent; agent_save is the bare module used for
    # checkpointing and single-rank evaluation
    if world_size > 1 :
        if args.tf_model == 'gpt2' : agent = DDP(init_agent(args), device_ids = [rank], find_unused_parameters = True)
        elif args.tf_model == 'scratch' : agent = DDP(init_agent(args), device_ids = [rank], find_unused_parameters = False)
        else : raise NotImplementedError
        agent_save = agent.module
    else :
        agent = init_agent(args)
        agent_save = agent

    # create directory for saving model
    load_model = False
    if args.directory is not None :
        if rank == 0 and not os.path.exists(args.directory) :
            os.makedirs(args.directory, exist_ok = True)
            logger.info(f"model path: {args.directory}")
        args.directory += "/state.pth"
        # Decide from the checkpoint file itself rather than from whether rank 0
        # had to create the directory: every rank then makes the same choice and
        # none tries to load a file that was never written. No rank can have
        # saved yet, since saving only happens inside the training loop.
        load_model = os.path.isfile(args.directory)
    if args.load and not load_model and rank == 0 :
        logger.warning("load requested but no checkpoint found; training from scratch")

    # fill in your wandb information
    if args.wandb and rank == 0 :
        wandb.init(
            name="XX",
            project="HMM",
            entity="YY",
            resume="allow"
        )
    num_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
    report(f"Number of parameters: {num_params}")

    start_time = datetime.datetime.now()
    # whether to generate fresh data for each epoch
    if not args.fresh_data_per_epoch :
        data, label = generate_data(args, environ, args.length)
        end_time = datetime.datetime.now()
        report("Data generation time at rank {}: {}".format(rank, (end_time - start_time).seconds))

    optimizer = torch.optim.AdamW(agent.parameters(), args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = args.lr_decay_gap, gamma = args.lr_decay_rate)

    epoch_start = 0
    optim_step = 0
    inds = np.arange(args.num_traj)
    output_filter = nn.Identity() if args.loss == 'mse' else nn.Softmax(dim = 1)
    lossfunc = nn.MSELoss() if args.loss == 'mse' else nn.CrossEntropyLoss()

    # curriculum learning setup (the first epoch grows cur_len once more)
    if not args.curriculum :
        cur_len = args.length
    else :
        cur_len = 2 ** (args.num_layers - 1) if args.curriculum_double else args.curriculum_init - args.curriculum_step

    # block CoT setup: position 0 is a zero token, block o's inputs occupy
    # [period_start[o], period_end[o]) and are followed by one feedback token
    # at period_end[o]; the last block may be shorter than block_size
    if args.block_size > 0 :
        assert not args.curriculum, "curriculum not correctly implemented for block CoT"
        optimize_iteration = (cur_len - 1) // args.block_size + 1
        period_start = [o * args.block_size + o + 1 for o in range(optimize_iteration)]
        period_end = [period_start[o] + args.block_size for o in range(optimize_iteration)]
        period_len = [args.block_size for o in range(optimize_iteration)]
        period_end[-1] = cur_len + optimize_iteration
        period_len[-1] = period_end[-1] - period_start[-1]

    # load the model from checkpoint
    if load_model and args.load :
        # map_location puts every rank's copy on its own device instead of the
        # device the checkpoint was saved from
        load_state = torch.load(args.directory, map_location = args.device)
        epoch_start = load_state['epoch']
        cur_len = load_state.get('cur_len', cur_len)
        agent_save.load_state_dict(load_state['model'])
        optimizer.load_state_dict(load_state['optimizer'])
        scheduler.load_state_dict(load_state['scheduler'])
        logger.info(f"Model loaded on rank {rank}")

    # training loop
    for epoch in range(epoch_start, args.epoch) :
        np.random.shuffle(inds)

        if args.curriculum and epoch % args.curriculum_update_freq == 0 and cur_len < args.length :
            if args.curriculum_double : cur_len = min(cur_len * 2, args.length)
            else : cur_len = min(cur_len + args.curriculum_step, args.length)

        if rank == 0 :
            report("Entering Epoch {} for length {}".format(epoch, cur_len))
        if args.fresh_data_per_epoch :
            start_time = datetime.datetime.now()
            data, label = generate_data(args, environ, args.length)
            end_time = datetime.datetime.now()
            if rank == 0 :
                report("Fresh data generation time: {}".format((end_time - start_time).seconds))
        # AutoRegression does not train on the first dof - 1 positions
        updated_interval = range(cur_len) if args.env != 'AutoRegression' else range(environ.dof - 1, cur_len, 1)
        size_updated_interval = len(updated_interval)

        start_time = datetime.datetime.now()
        for t in range(0, args.num_traj, args.batch_size) :
            b_inds = inds[t : t + args.batch_size]
            n_batch = b_inds.shape[0]
            batch_input, target = load_data(args, data, label, b_inds)
            if args.real : target = batch_input[:, 1:, :args.output_size] # next-observation prediction
            if args.block_size > 0 :
                # the (i*block_size)-th entry represents the state at the beginning of the i-th block
                network_input = torch.zeros(n_batch, cur_len + optimize_iteration + 1, args.net_input_size).to(args.device)

            if args.agent == 'MLP' : # use latest block_size entries to predict the next entry
                for i in range(cur_len) :
                    # window of the latest block_size entries, zero-padded at the end
                    mlp_input = torch.zeros((n_batch, args.block_size * args.input_size)).to(args.device)
                    feedback_length = min(i + 1, args.block_size)
                    mlp_input[:, :feedback_length * args.input_size] = batch_input[:, i - feedback_length + 1 : i + 1, :].reshape(n_batch, -1)
                    output = agent(mlp_input)
                    loss = lossfunc(output, target[:, i, :])
                    # count MLP updates too; otherwise the epoch-level
                    # scheduler.step() is never reached when warmup > 0
                    optim_step += 1
                    optimizer.zero_grad()
                    loss.backward()
                    grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                    optimizer.step()
            elif args.agent != 'TF' : # RNN, LSTM
                output = agent(batch_input[:, updated_interval, :])
                loss = lossfunc(output.view(n_batch * size_updated_interval, -1),
                                target[:, updated_interval, :].view(n_batch * size_updated_interval, -1))
                optim_step += 1
                optimizer.zero_grad()
                loss.backward()
                grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()
            elif args.block_size == 0 : # Transformer without block CoT
                output = agent(batch_input[:, updated_interval, :]) # shape (bs, cur_len, output_size)
                loss = lossfunc(output.view(n_batch * size_updated_interval, -1),
                                target[:, updated_interval, :].view(n_batch * size_updated_interval, -1))
                optim_step += 1
                apply_warmup(optim_step)
                optimizer.zero_grad()
                loss.backward()
                grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()
            else : # block CoT: one optimizer step per block
                for o in range(optimize_iteration) :
                    network_input[:, period_start[o] : period_end[o], :] = F.pad(
                        batch_input[:, o * args.block_size : o * args.block_size + period_len[o], :], (0, args.net_input_size - args.input_size),
                        'constant', 0
                    )
                    output = agent(network_input[:, :period_end[o],:]) # shape (bs, length, output_size)
                    loss = lossfunc(output[:, period_start[o] : period_end[o], :].reshape(n_batch * period_len[o], -1),
                                    target[:, o * args.block_size : o * args.block_size + period_len[o], :].reshape(n_batch * period_len[o], -1))

                    # feed the (detached) last prediction back as the next token
                    network_input[:, period_end[o], :] = F.pad(
                        output[:, -1, :].detach(), (0, args.net_input_size - args.output_size),
                        'constant', 0
                    )

                    optim_step += 1
                    apply_warmup(optim_step)
                    optimizer.zero_grad()
                    loss.backward()
                    grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                    optimizer.step()

            # logging loss function
            if (t // args.batch_size) % args.log_every_steps == 0 and rank == 0 :
                if args.wandb :
                    wandb.log({
                        "loss" : loss.item(),
                        "grad_norm" : grad_norm
                    }, step = epoch * args.num_traj + t)
                report("Epoch: {}, Step: {}, Loss: {}, Grad Norm: {} Time: {}".format(epoch, t, loss.item(), grad_norm, (datetime.datetime.now() - start_time).seconds))
                start_time = datetime.datetime.now()
                if args.directory is not None :
                    save_checkpoint(epoch, cur_len)
        if rank == 0 :
            report("Epoch {} ends, Time: {}".format(epoch, (datetime.datetime.now() - start_time).seconds))
            if args.directory is not None :
                save_checkpoint(epoch, cur_len)

        if optim_step >= args.warmup :
            scheduler.step()

        if args.noeval or rank > 0 : continue
        # Evaluate the bare module, not the DDP wrapper: only rank 0 gets here,
        # so no DDP collective (e.g. buffer broadcast) may be issued. Eval mode
        # disables dropout during the rollout.
        agent_save.eval()
        try :
            eval_val = evaluate(args, environ, cur_len, agent_save, output_filter)
        finally :
            agent_save.train()
        for tt in range(-1, cur_len, args.log_evaluate_step) :
            if args.wandb :
                wandb.log(
                    {
                        "eval_loss_" + str(tt) : eval_val[0][tt + 1]
                    }, step = epoch * args.num_traj + t
                )
            if args.env != 'CyclicHard' :
                report("Epoch: {} Step: {}".format(epoch, tt + 1) + " evaluation loss: {}".format(eval_val[0][tt + 1]))
            else :
                report("Epoch: {} Step: {}".format(epoch, tt + 1) + " avg evaluation loss: {} state evaluation loss: {}".format(eval_val[0][tt + 1], eval_val[1][tt + 1]))

    # release the process group created by ddp_setup
    if world_size > 1 :
        dist.destroy_process_group()
    
if __name__ == "__main__" :
    args = parse_args()
    print(args)
    logger.info(args)

    # gpu setup: expose GPUs gpu_bias .. gpu_bias + world_size - 1 (capped at
    # index 7). CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
    # initialised, so do it before init_env in case the environment touches CUDA.
    visible_devices = ",".join(str(i) for i in range(args.gpu_bias, min(8, args.gpu_bias + args.world_size)))
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    print("CUDA devices: {}".format(visible_devices))
    logger.info("CUDA devices: {}".format(visible_devices))

    np.random.seed(args.seed)
    environ = init_env(args)

    if args.world_size > 1 :
        # mp.spawn passes each process its rank as run's first argument
        mp.spawn(run, args=(args.world_size, args, environ), nprocs=args.world_size, join=True)
    else :
        run(0, 1, args, environ)