from copy import deepcopy
import itertools
import numpy as np
import torch
from torch.optim import Adam
import gym
import time
import core_file as core
from stable_baselines3 import SAC # Just use stable_baselines3 to load source model
from tqdm import tqdm
import random
import os
from robosuite.wrappers import GymWrapper
import robosuite as suite
import csv
import math
from core.flow.real_nvp import RealNvp

# import baselines.environments.register as register
# import baselines.environments.init_path as init_path
# init_path.bypass_frost_warning()

import bark_ml.environments.gym

def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to every generator.

    NOTE: PYTHONHASHSEED is read at interpreter start-up, so exporting it
    here only affects subprocesses launched afterwards.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class ReplayBuffer:
    """
    Fixed-capacity FIFO experience replay buffer for SAC agents.

    Transitions are written into preallocated float32 arrays. Once the buffer
    is full, the write position wraps around and overwrites the oldest entry.
    """

    def __init__(self, obs_dim, act_dim, size,device):
        obs_shape = core.combined_shape(size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next slot to write; size: number of valid rows; max_size: capacity.
        self.ptr = 0
        self.size = 0
        self.max_size = size
        self.device = device

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the current slot and advance the write position."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.done_buf[slot] = done
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` stored transitions uniformly with replacement.

        Returns:
            dict with keys 'obs', 'obs2', 'act', 'rew', 'done', each a float32
            tensor on ``self.device``.
        """
        idxs = np.random.randint(0, self.size, size=batch_size)
        columns = (
            ('obs', self.obs_buf),
            ('obs2', self.obs2_buf),
            ('act', self.act_buf),
            ('rew', self.rew_buf),
            ('done', self.done_buf),
        )
        return {
            name: torch.as_tensor(buf[idxs], dtype=torch.float32, device=self.device)
            for name, buf in columns
        }
    

class sac:
    """Soft Actor-Critic agent for a target domain, assisted by a pretrained source-domain model.

    Target observations are mapped into the source domain by ``decoder``
    followed by the frozen ``state_flow``. Those mapped inputs, together with
    the schedule value ``p`` from ``calculate_p``, are handed to
    ``core.MLPActorCritic``, which is built with ``source_model``.

    While ``p < 1``, the mapping networks and the actor-critic are trained
    jointly with cycle, mutual-information and SAC losses. Once ``p == 1``,
    no source-domain inputs are computed and only the actor-critic is trained
    with plain SAC losses.

    NOTE: the constructor immediately runs ``train()``.
    """

    def __init__(self,env,
                 test_env,
                 src_env,
                 p_point1,
                 p_point2,
                 ac_kwargs=None,
                 total_steps= int(5e5),
                 log_folder=None,
                 test_steps=50,
                 seed = None,
                 replay_size=int(1e6),
                 gamma=0.99,
                 polyak=0.995,
                 lr=3e-4,
                 batch_size=256,
                 start_steps=100,
                 num_test_episodes=5,
                 source_model=None,
                 device=None,
                 sourceDataFile=None,
                 state_flow=None,
                 action_flow=None):
        """Build networks and optimizers, then run ``train()``.

        Args:
            env: Target-domain training environment (4-tuple ``step`` API).
            test_env: Target-domain environment used by ``test_agent``.
            src_env: Source-domain environment; only its observation/action
                space shapes are used.
            p_point1, p_point2: Fractions of ``total_steps`` at which ``p``
                starts ramping up from 0 and reaches 1 (see ``calculate_p``).
            ac_kwargs: Extra keyword arguments for ``core.MLPActorCritic``.
            total_steps: Number of environment steps to train for.
            log_folder: Path prefix (e.g. a directory with trailing
                separator) to which ``progress.csv`` is appended.
            test_steps: Evaluate every ``test_steps`` environment steps.
            seed: Not used inside this class; the caller seeds the RNGs.
            replay_size: Replay buffer capacity.
            gamma: Discount factor.
            polyak: Averaging coefficient for the target actor-critic.
            lr: Learning rate of the network optimizers.
            batch_size: Minibatch size for replay and source-data sampling.
            start_steps: Number of initial steps that take random actions.
            num_test_episodes: Episodes per evaluation.
            source_model: Pretrained source-domain model given to the actor-critic.
            device: Torch device for networks and sampled tensors.
            sourceDataFile: Path prefix of the ``state.npy`` / ``action.npy``
                source datasets.
            state_flow: Frozen flow whose ``g`` maps decoded observations
                into the source domain.
            action_flow: Stored but currently unused (the action mapping is
                bypassed, see ``_to_source_action``).
        """
        if ac_kwargs is None:
            ac_kwargs = {}

        self.env = env
        self.test_env = test_env
        self.total_steps = total_steps
        self.log_folder = log_folder
        self.test_steps = test_steps
        self.gamma = gamma
        self.polyak = polyak
        self.batch_size = batch_size
        self.start_steps = start_steps
        self.num_test_episodes = num_test_episodes
        self.device = device
        self.state_flow = state_flow
        self.action_flow = action_flow

        if sourceDataFile is not None:
            self.sourceState = np.load(sourceDataFile + 'state.npy').astype(np.float32)
            self.sourceAction = np.load(sourceDataFile + 'action.npy').astype(np.float32)
        else:
            self.sourceState = None
            self.sourceAction = None

        self.p_point1 = p_point1
        self.p_point2 = p_point2

        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        src_obs_dim = src_env.observation_space.shape[0]
        src_act_dim = src_env.action_space.shape[0]

        # Cross-domain mapping networks: decoders go target -> source,
        # encoders go source -> target, and the variational nets predict
        # target quantities from source ones (used by the MI losses).
        self.decoder = core.decoder_network(obs_dim, src_obs_dim, 256, device, outputScale=0.5)
        self.encoder = core.decoder_network(src_obs_dim, obs_dim, 256, device)
        self.action_decoder = core.decoder_network(act_dim, src_act_dim, 256, device, outputScale=0.5)
        self.action_encoder = core.decoder_network(src_act_dim, act_dim, 256, device, outputScale=1.0)
        self.variationa_net = core.decoder_network(src_obs_dim, obs_dim, 64, device, variation=True)
        self.variationa_action_net = core.decoder_network(src_act_dim, act_dim, 64, device, outputScale=1.0, variation=True)

        # Learnable log-temperature, initialised so that the entropy coefficient starts at 1.0.
        self.log_ent_coef = torch.log(torch.ones(1, device=device) * 1.0).requires_grad_(True)
        self.ent_coef_optimizer = torch.optim.Adam([self.log_ent_coef], lr = 1e-3)
        self.target_entropy = -np.prod(env.action_space.shape).astype(np.float32)
        self.ent_coef = torch.exp(self.log_ent_coef.detach()).item()

        self.ac = core.MLPActorCritic(env.observation_space, env.action_space, source_model, device, **ac_kwargs)
        self.ac_targ = deepcopy(self.ac)
        # Frozen snapshot of ``ac`` whose critics score the policy in ``compute_loss_pi``.
        self.ac_PI_Learner = deepcopy(self.ac)
        self.p = 0.0
        self.data = None

        # Joint optimizer used while the source domain is still involved (p < 1).
        self.enc_optimizer = Adam(itertools.chain(
            self.decoder.parameters(),
            self.encoder.parameters(),
            self.variationa_net.parameters(),
            self.action_decoder.parameters(),
            self.action_encoder.parameters(),
            self.variationa_action_net.parameters(),
            self._trainable(self.ac.q1),
            self._trainable(self.ac.q2),
            self._trainable(self.ac.pi),
        ), lr=lr)

        # Separate critic / actor optimizers for the independent phase (p == 1).
        self.indep_q_optimizer = Adam(itertools.chain(
            self._trainable(self.ac.q1),
            self._trainable(self.ac.q2),
        ), lr=lr)
        self.indep_pi_optimizer = Adam(self._trainable(self.ac.pi), lr=lr)

        for frozen in (self.ac_targ, self.ac_PI_Learner, self.state_flow):
            for param in frozen.parameters():
                param.requires_grad = False

        self.replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size, device=device)
        self.MI_losses = []
        self.MI_action_losses = []
        self.cycle_losses = []
        self.Q_losses = []
        self.pi_losses = []
        self.ent_Losses = []
        self.ent_coefs = []
        self.train()

    @staticmethod
    def _trainable(module):
        """Yield the parameters of ``module`` that currently require gradients."""
        return (param for param in module.parameters() if param.requires_grad)

    @staticmethod
    def _mean(values):
        """Mean of a logged-loss list, or NaN (without a RuntimeWarning) when it is empty."""
        return np.mean(values) if values else float('nan')

    def calculate_p(self, time_steps):
        """Return the source-to-target schedule value for ``time_steps``.

        ``p`` is 0 up to ``total_steps * p_point1``. It then ramps linearly to
        1 at ``total_steps * p_point2`` and stays at 1 afterwards.
        """
        start = self.total_steps * self.p_point1
        end = self.total_steps * self.p_point2
        if time_steps <= start:
            return 0.0
        if time_steps <= end:
            return (time_steps - start) / (end - start)
        return 1.0

    def _to_source_obs(self, o):
        """Map target-domain observations to the source domain: ``decoder``, then the frozen ``state_flow.g``."""
        return self.state_flow.g(self.decoder(o).double())[0].float()

    def _to_source_action(self, a):
        """Return the source-domain counterpart of target-domain actions.

        The learned action mapping (``action_decoder`` + ``action_flow``) is
        currently bypassed, so actions are passed through unchanged.
        """
        return a

    def get_action(self, o, deterministic=False):
        """Select an action for a single observation ``o`` with the current actor."""
        with torch.no_grad():
            o = torch.as_tensor(o, dtype=torch.float32, device=self.device).unsqueeze(dim=0)
            # [0] takes the single row of the 1-element batch; ``o`` keeps its batch dim.
            src_obs = self._to_source_obs(o)[0] if self.p < 1.0 else None
        return self.ac.act(o, src_obs, self.p, deterministic)[0]

    def compute_entropy_loss(self, p):
        """SAC temperature loss: ``-(log_alpha * (log_pi + target_entropy)).mean()``.

        Also refreshes ``self.ent_coef`` from the current ``log_ent_coef``.
        Only ``log_ent_coef`` receives gradients, so the policy forward pass
        runs without building an autograd graph.
        """
        o = self.data['obs']
        with torch.no_grad():
            src_o = None if p == 1.0 else self._to_source_obs(o)
            _, logp_pi = self.ac.pi(o, src_o, p)
        self.ent_coef = torch.exp(self.log_ent_coef.detach()).item()
        return -(self.log_ent_coef * (logp_pi.reshape(-1, 1) + self.target_entropy)).mean()

    def compute_MI_loss(self):
        """Gaussian NLL (up to a constant) of target observations predicted from their source mapping."""
        o = self.data['obs']
        mu = self.variationa_net(self._to_source_obs(o))
        logstd = self.variationa_net.logstd
        return (logstd + ((o - mu) ** 2) / (2 * torch.exp(logstd) ** 2)).mean()

    def compute_MI_action_loss(self):
        """Gaussian NLL (up to a constant) of target actions predicted from their source mapping."""
        a = self.data['act']
        mu = self.variationa_action_net(self._to_source_action(a))
        logstd = self.variationa_action_net.logstd
        return (logstd + ((a - mu) ** 2) / (2 * torch.exp(logstd) ** 2)).mean()

    def compute_cycle_loss(self):
        """Sum of four squared reconstruction errors.

        The four terms are:
        - target obs -> source -> target obs (replay batch);
        - target action -> source -> target action (replay batch);
        - source obs -> target -> source obs (source dataset);
        - source action -> target -> source action (source dataset).

        Raises:
            ValueError: if no source dataset was loaded (``sourceDataFile`` was None).
        """
        if self.sourceState is None or self.sourceAction is None:
            raise ValueError("compute_cycle_loss requires source data; pass sourceDataFile to sac()")

        o = self.data['obs']
        recon_o = self.encoder(self._to_source_obs(o))
        cycle_loss1 = ((recon_o - o) ** 2).mean()

        a = self.data['act']
        recon_a = self.action_encoder(self._to_source_action(a))
        cycle_loss2 = ((recon_a - a) ** 2).mean()

        # Sample paired source rows from the whole dataset, without replacement.
        n_source = min(len(self.sourceState), len(self.sourceAction))
        idx = np.random.choice(n_source, min(self.batch_size, n_source), replace=False)

        so = torch.tensor(self.sourceState[idx], device=self.device)
        recon_so = self._to_source_obs(self.encoder(so))
        cycle_loss3 = ((recon_so - so) ** 2).mean()

        sa = torch.tensor(self.sourceAction[idx], device=self.device)
        recon_sa = self._to_source_action(self.action_encoder(sa))
        cycle_loss4 = ((recon_sa - sa) ** 2).mean()

        return cycle_loss1 + cycle_loss2 + cycle_loss3 + cycle_loss4

    def compute_loss_pi(self, p):
        """SAC policy loss ``(ent_coef * log_pi - min(Q1, Q2)).mean()`` using the ``ac_PI_Learner`` critics."""
        o = self.data['obs']
        use_source = p != 1.0
        src_o = self._to_source_obs(o) if use_source else None

        pi, logp_pi = self.ac.pi(o, src_o, p)
        src_pi = self._to_source_action(pi) if use_source else None

        q1_pi = self.ac_PI_Learner.q1(o, pi, src_o, src_pi, p)
        q2_pi = self.ac_PI_Learner.q2(o, pi, src_o, src_pi, p)
        q_pi = torch.min(q1_pi, q2_pi)
        return (self.ent_coef * logp_pi - q_pi).mean()

    def compute_loss_q(self, p):
        """Summed MSE of both critics against the entropy-regularised clipped double-Q target."""
        o, a, r, o2, d = (self.data[key] for key in ('obs', 'act', 'rew', 'obs2', 'done'))
        use_source = p != 1.0

        src_o = self._to_source_obs(o) if use_source else None
        src_a = self._to_source_action(a) if use_source else None
        q1 = self.ac.q1(o, a, src_o, src_a, p)
        q2 = self.ac.q2(o, a, src_o, src_a, p)

        with torch.no_grad():
            src_o2 = self._to_source_obs(o2) if use_source else None
            a2, logp_a2 = self.ac.pi(o2, src_o2, p)
            src_a2 = self._to_source_action(a2) if use_source else None

            q1_pi_targ = self.ac_targ.q1(o2, a2, src_o2, src_a2, p)
            q2_pi_targ = self.ac_targ.q2(o2, a2, src_o2, src_a2, p)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + self.gamma * (1 - d) * (q_pi_targ - self.ent_coef * logp_a2)

        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        return loss_q1 + loss_q2

    def _sync_pi_learner(self):
        """Hard-copy the current ``ac`` parameters into the frozen ``ac_PI_Learner``."""
        with torch.no_grad():
            for param, snap in zip(self.ac.parameters(), self.ac_PI_Learner.parameters()):
                snap.data.copy_(param.data)

    def _update_joint(self):
        """Gradient steps while the source domain is in use (p < 1).

        First a cycle-consistency step, then a joint step on MI + Q + policy
        losses, all through ``enc_optimizer``.
        """
        self.enc_optimizer.zero_grad()
        cycle_loss = self.compute_cycle_loss()
        cycle_loss.backward()
        self.enc_optimizer.step()
        self.cycle_losses.append(cycle_loss.item())

        MI_loss = self.compute_MI_loss()
        # The action MI term is only logged; it does not enter the backward pass below.
        MI_action_loss = self.compute_MI_action_loss()
        q_loss = self.compute_loss_q(self.p)
        loss_pi = self.compute_loss_pi(self.p)

        # The state MI term is halved once the schedule has started ramping (p > 0).
        mi_weight = 1.0 if self.p == 0.0 else 0.5
        self.enc_optimizer.zero_grad()
        (MI_loss * mi_weight + q_loss + loss_pi).backward()
        self.enc_optimizer.step()

        self.MI_losses.append(MI_loss.item())
        self.MI_action_losses.append(MI_action_loss.item())
        self.Q_losses.append(q_loss.item())
        self.pi_losses.append(loss_pi.item())

    def _update_independent(self):
        """Plain SAC critic then actor steps once the source domain is no longer used (p == 1)."""
        self.indep_q_optimizer.zero_grad()
        q_loss = self.compute_loss_q(1.0)
        q_loss.backward()
        self.indep_q_optimizer.step()

        # The policy is scored by a snapshot of the freshly updated critics.
        self._sync_pi_learner()

        self.indep_pi_optimizer.zero_grad()
        loss_pi = self.compute_loss_pi(1.0)
        loss_pi.backward()
        self.indep_pi_optimizer.step()

        self.Q_losses.append(q_loss.item())
        self.pi_losses.append(loss_pi.item())

    def update(self):
        """One training update on ``self.data``.

        Runs a temperature step, then the phase-specific network steps, then
        the Polyak update of ``ac_targ``.
        """
        self.ent_coef_optimizer.zero_grad()
        ent_coef_loss = self.compute_entropy_loss(self.p)
        ent_coef_loss.backward()
        self.ent_coef_optimizer.step()
        self.ent_Losses.append(ent_coef_loss.item())
        self.ent_coefs.append(self.ent_coef)

        if self.p < 1.0:
            self._update_joint()
        else:
            self._update_independent()

        with torch.no_grad():
            for param, targ_param in zip(self.ac.parameters(), self.ac_targ.parameters()):
                targ_param.data.mul_(self.polyak)
                targ_param.data.add_((1 - self.polyak) * param.data)

        if self.p < 1.0:
            self._sync_pi_learner()

    def test_agent(self, time_steps):
        """Evaluate the deterministic policy, append a row to ``progress.csv`` and reset loss logs.

        The logged loss columns are means over the updates since the previous
        evaluation (NaN when no such update happened).
        """
        total_reward = 0.0
        total_steps = 0.0
        for _ in range(self.num_test_episodes):
            o, d = self.test_env.reset(), False
            while not d:
                o, r, d, _ = self.test_env.step(self.get_action(o, True))
                total_reward += r
                total_steps += 1

        test_rewards = total_reward / self.num_test_episodes
        test_steps = total_steps / self.num_test_episodes

        row = [time_steps, test_rewards, test_steps,
               self._mean(self.MI_losses), self._mean(self.MI_action_losses),
               self.p, self._mean(self.cycle_losses), self._mean(self.Q_losses),
               self._mean(self.pi_losses), self._mean(self.ent_Losses), self._mean(self.ent_coefs)]
        with open(self.log_folder + 'progress.csv', 'a', newline='') as csvfile:
            csv.writer(csvfile).writerow(row)

        print(f"Training Steps          : {time_steps}")
        print("Testing Average Return  : {:.2f}".format(test_rewards))
        print("Testing Average Steps   : {:.2f}".format(test_steps))

        # Reset every per-interval log, including the entropy ones, so each
        # CSV row reflects only the latest interval and the lists stay bounded.
        for history in (self.MI_losses, self.MI_action_losses, self.cycle_losses,
                        self.Q_losses, self.pi_losses, self.ent_Losses, self.ent_coefs):
            history.clear()

    def train(self):
        """Interact with ``env`` for ``total_steps`` steps, updating and evaluating along the way.

        Random actions are used while ``t <= start_steps``. Updates run every
        step from ``start_steps`` on, and an evaluation runs every
        ``test_steps`` steps.
        """
        o = self.env.reset()
        for t in tqdm(range(self.total_steps)):
            self.p = self.calculate_p(t)
            if t > self.start_steps:
                a = self.get_action(o)
            else:
                a = self.env.action_space.sample()

            o2, r, d, info = self.env.step(a)
            # Episodes cut off by a time limit are not true terminal states, so
            # keep bootstrapping through them. gym's TimeLimit wrapper marks
            # such steps with info['TimeLimit.truncated'].
            timeout = isinstance(info, dict) and bool(info.get('TimeLimit.truncated', False))
            self.replay_buffer.store(o, a, r, o2, d and not timeout)
            o = self.env.reset() if d else o2

            if t >= self.start_steps:
                self.data = self.replay_buffer.sample_batch(self.batch_size)
                self.update()

            if (t + 1) % self.test_steps == 0:
                self.test_agent(t + 1)

def get_env(env_name):
    if env_name == 'Door' or env_name == 'Wipe':
        target_env = GymWrapper(
                suite.make(
                    env_name,
                    robots="UR5e",  # use Sawyer robot
                    use_camera_obs=False,  # do not use pixel observations
                    has_offscreen_renderer=False,  # not needed since not using pixel obs
                    has_renderer=False,  # make sure we can render to the screen
                    reward_shaping=True,  # use dense rewards
                    control_freq=20,  # control should happen fast enough so that simulation looks smooth
                )
        )
        eval_env = GymWrapper(
                suite.make(
                    env_name,
                    robots="UR5e",  # use Sawyer robot
                    use_camera_obs=False,  # do not use pixel observations
                    has_offscreen_renderer=False,  # not needed since not using pixel obs
                    has_renderer=False,  # make sure we can render to the screen
                    reward_shaping=True,  # use dense rewards
                    control_freq=20,  # control should happen fast enough so that simulation looks smooth
                )
        )
        src_env = GymWrapper(
                suite.make(
                    env_name,
                    robots="Panda",  # use Sawyer robot
                    use_camera_obs=False,  # do not use pixel observations
                    has_offscreen_renderer=False,  # not needed since not using pixel obs
                    has_renderer=False,  # make sure we can render to the screen
                    reward_shaping=True,  # use dense rewards
                    control_freq=20,  # control should happen fast enough so that simulation looks smooth
                )
        )
        return target_env, eval_env, src_env
    
    if env_name == 'HalfCheetah-v3' or env_name == 'Ant-v3':
        xml_file = None
        if env_name == 'HalfCheetah-v3':
            xml_file = "cheetah_target.xml"
        elif env_name == 'Ant-v3':
            xml_file = "ant_target.xml"
        return gym.make(env_name, xml_file=xml_file), gym.make(env_name, xml_file=xml_file), gym.make(env_name)
    
    raise ValueError(f"Environment {env_name} is not correct")
    
if __name__ == '__main__':
    # Script entry point: load the source SAC model and flows, build the
    # environments, seed everything, initialise progress.csv, then train.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='HalfCheetah-v3')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--timesteps', type=float, default=5e5)
    parser.add_argument('--log_folder', type=str, default='')
    parser.add_argument('--source_file', type=str, default="")
    parser.add_argument('--flow_file', type=str, default="")
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--p_point1', type=float, default=0.2)
    parser.add_argument('--p_point2', type=float, default=0.6)

    args = parser.parse_args()
    print(args)

    seed_dir = str(args.seed) + '/'
    log_folder = args.log_folder + f"{args.p_point1}_{args.p_point2}/" + seed_dir
    # Per-seed prefix holding the source SAC checkpoint ("best_model") and the
    # state.npy / action.npy source datasets.
    source_file = args.source_file + seed_dir
    progress_file = log_folder + 'progress.csv'
    banner = "#"*50 + "\n" + "Logging to file: " + progress_file + '\n' + "#"*50

    source_model = SAC.load(source_file + "best_model", device=args.device)
    state_flow = RealNvp.load_module(args.flow_file + 'state/flow_seed' + str(args.seed) + ".pt").to(args.device)
    action_flow = RealNvp.load_module(args.flow_file + 'action/flow_seed' + str(args.seed) + ".pt").to(args.device)

    target_env, eval_env, src_env = get_env(args.env)

    target_env.seed(seed=args.seed)
    target_env.action_space.seed(seed=args.seed)
    eval_env.seed(seed=args.seed)
    set_random_seed(seed=args.seed)

    os.makedirs(log_folder, exist_ok=True)
    print(banner)

    with open(progress_file, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['time/total_timesteps','eval/mean_reward','eval/mean_ep_length','MI_Loss','MI_action_Loss', 'p', 'Cycle_Loss', 'Q_Loss', 'PI_Loss', 'ent_Loss', 'ent_coef'])

    # Constructing the agent runs the full training loop.
    sac(target_env, eval_env, src_env, p_point1=args.p_point1, p_point2=args.p_point2, total_steps=int(args.timesteps), log_folder=log_folder, gamma=args.gamma,
        ac_kwargs=dict(hidden_sizes=256), source_model=source_model, device=args.device, sourceDataFile=source_file, seed=args.seed,
        state_flow=state_flow, action_flow=action_flow)
    print(banner)
