import os
import argparse
#import pybullet_envs
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from itertools import count
from collections import deque
import torch.optim as optim
import random
from torch.distributions import Categorical
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import gym
import warnings
from tqdm import tqdm
import shelve
import csv

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and the current CUDA
    device), exports ``PYTHONHASHSEED``, and switches cuDNN to deterministic
    mode with auto-tuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Same seeding order as always: random -> numpy -> (env var) -> torch -> cuda.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def put_transition(buffer, *transition):
    """Append one transition to ``buffer``.

    Args:
        buffer: Any container exposing ``append`` (a ``deque`` in this script).
        *transition: The fields of the transition; they are stored together
            as a single tuple, in the order given.
    """
    record = transition  # *args already arrives as a tuple
    buffer.append(record)

class env_model(nn.Module):
    """MLP that maps a concatenated (state, action) pair to a next-state prediction.

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        state_dim: Width of the state input and of the predicted next state.
        action_dim: Width of the action input.
    """

    def __init__(self, args, state_dim, action_dim):
        super().__init__()
        # Layer widths, input first; every hidden Linear is followed by LeakyReLU(0.2).
        widths = [state_dim + action_dim, 32, 64, 128, 32]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(widths[-1], state_dim))  # linear output head
        self.model = nn.Sequential(*stack)

        self.device = args.device
        self.state_dim = state_dim
        self.action_dim = action_dim

    def forward(self, state, action):
        """Return the predicted next state for 2-D ``state`` and ``action`` batches."""
        joint_input = torch.cat((state, action), 1)  # concatenate along the feature axis
        return self.model(joint_input)
   
class state_alignment(nn.Module):
    """Maps a target-domain state to a source-domain state.

    The flat target state is reshaped into a single-channel 1 x D "image",
    passed through a ResNet-18 whose first convolution accepts one channel,
    and the 1000-wide ResNet output is reduced by an MLP head to
    ``source_state_dim`` features.

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        source_state_dim: Width of the produced source-domain state.
        target_state_dim: Width of the incoming target-domain state.
    """

    def __init__(self, args, source_state_dim, target_state_dim):
        super(state_alignment, self).__init__()
        self.resnet = models.resnet18(weights=None)
        # Replace the stock 3-channel stem with a 1-channel one.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model = nn.Sequential(
            nn.Linear(1000, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(32, source_state_dim),
        )
        self.device = args.device
        self.source_state_dim = source_state_dim
        self.target_state_dim = target_state_dim

    def forward(self, state):
        """Return the aligned source state for a batch of target states.

        ``state`` may have any shape whose element count is a multiple of
        ``target_state_dim``; it is viewed as (-1, 1, 1, target_state_dim).
        """
        # Use the dimension this module was built with, not the script-level
        # ``args`` global, so the class also works when imported elsewhere.
        state = state.view(-1, 1, 1, self.target_state_dim)
        source_state = self.model(self.resnet(state))
        return source_state
    
class action_alignment(nn.Module):
    """MLP that turns a (source state, source action) pair into a target-domain action.

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        source_state_dim: Width of the source-state input.
        source_action_dim: Width of the source-action input.
        target_action_dim: Width of the produced target action.
    """

    def __init__(self, args, source_state_dim, source_action_dim, target_action_dim):
        super().__init__()
        in_features = source_state_dim + source_action_dim
        hidden_sizes = (32, 64, 128, 32)

        modules = []
        previous = in_features
        for size in hidden_sizes:
            modules += [nn.Linear(previous, size), nn.LeakyReLU(0.2, inplace=True)]
            previous = size
        modules.append(nn.Linear(previous, target_action_dim))  # unbounded linear output
        self.model = nn.Sequential(*modules)

        self.device = args.device
        self.source_state_dim = source_state_dim
        self.source_action_dim = source_action_dim

    def forward(self, state, action):
        """Return the target action for 2-D ``state`` and ``action`` batches."""
        features = torch.cat((state, action), 1)  # join along the feature axis
        return self.model(features)
    
class source_discriminator(nn.Module):
    """MLP with a sigmoid output that scores a state vector with a value in (0, 1).

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        state_dim: Width of the state input.
    """

    def __init__(self, args, state_dim):
        super().__init__()
        hidden = (32, 64, 128, 32)
        dims = (state_dim, *hidden)
        # Linear + LeakyReLU(0.2) for each consecutive pair of widths.
        body = [
            layer
            for fan_in, fan_out in zip(dims, dims[1:])
            for layer in (nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True))
        ]
        self.model = nn.Sequential(*body, nn.Linear(hidden[-1], 1), nn.Sigmoid())

        self.device = args.device
        self.state_dim = state_dim

    def forward(self, state):
        """Return one score per row of ``state``, shaped (batch, 1)."""
        return self.model(state)
    
class target_discriminator(nn.Module):
    """MLP with a sigmoid output that scores an action vector with a value in (0, 1).

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        action_dim: Width of the action input.
    """

    def __init__(self, args, action_dim):
        super().__init__()
        layer_list = []
        width_in = action_dim
        for width_out in (32, 64, 128, 32):
            layer_list.append(nn.Linear(width_in, width_out))
            layer_list.append(nn.LeakyReLU(0.2, inplace=True))
            width_in = width_out
        # Single-logit head squashed to a probability-like score.
        layer_list.extend([nn.Linear(width_in, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*layer_list)

        self.device = args.device
        self.action_dim = action_dim

    def forward(self, action):
        """Return one score per row of ``action``, shaped (batch, 1)."""
        score = self.model(action)
        return score


# Command-line options as (flag, type, default) triples, in help-output order.
# Every option takes at most one value (nargs='?') and falls back to its default.
# NOTE(review): the four *_dim defaults must match the source policy and the
# target environment actually used — confirm before relying on them.
_CLI_OPTIONS = (
    ("--seed", int, 2),
    ("--env", str, "HalfCheetah-v3"),
    ("--total_step", int, 1000000),
    ("--folder", str, "/home"),
    ("--source_state_dim", int, 17),
    ("--source_action_dim", int, 23),
    ("--target_state_dim", int, 6),
    ("--target_action_dim", int, 9),
    ("--device", str, "cuda:3"),
    ("--env_lr", float, 1e-4),
    ("--D_lr", float, 1e-3),
    ("--G_lr", float, 1e-4),
    ("--source_batch_size", int, 256),
    ("--target_batch_size", int, 32),
    ("--xml_file", str, 'cheetah_target.xml'),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _fallback in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, nargs='?', default=_fallback)
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Experiment setup: output folders, seeding, environments, replay buffers,
# the pretrained source policy, and the networks/optimizers trained below.
# ---------------------------------------------------------------------------
log_folder = "./log_policy/" + str(args.env) + "/seed" + str(args.seed) + '/'
source_folder = args.folder + '/QAvatar/source/stable-baselines3-1.7.0/logs/source/seed' + str(args.seed) + '/'
# result.csv and the G1/G2 checkpoints are written into log_folder; create it
# up front so the first open(..., 'w') cannot fail on a fresh run.
os.makedirs(log_folder, exist_ok=True)

seed_everything(args.seed)
vec_env = DummyVecEnv([lambda: gym.make(args.env, xml_file = args.folder + '/assets/' + args.xml_file)])
eval_vec_env = DummyVecEnv([lambda: Monitor(gym.make(args.env, xml_file = args.folder + '/assets/' + args.xml_file))])
vec_env.seed(seed=args.seed)
vec_env.action_space.seed(seed=args.seed)
eval_vec_env.seed(seed=args.seed)
set_random_seed(seed = args.seed)

# Target-domain transitions collected during training (bounded FIFO).
target_buffer = deque(maxlen=1000000)
# Source-domain transitions saved by an earlier run. The value is fully
# unpickled on access, so the shelf can be closed right after reading.
# NOTE: shelve uses pickle under the hood — only open buffer files you trust.
with shelve.open(args.folder + '/QAvatar/source/stable-baselines3-1.7.0/buffer/buffer_seed' + str(args.seed)) as file:
    source_buffer = file['buffer']
source_model = SAC.load(source_folder + "best_model")

# Forward dynamics model of the target environment.
env_predict = env_model(args, args.target_state_dim, args.target_action_dim).to(args.device)
env_predict_optimizer = torch.optim.Adam(env_predict.parameters(), lr = args.env_lr)

# G1: target state -> source state; G2: (source state, source action) -> target action.
G1 = state_alignment(args, args.source_state_dim, args.target_state_dim).to(args.device)
G2 = action_alignment(args, args.source_state_dim, args.source_action_dim, args.target_action_dim).to(args.device)
Dsource = source_discriminator(args, args.source_state_dim).to(args.device)
Dtarget = target_discriminator(args, args.target_action_dim).to(args.device)

G1_optimizer = torch.optim.Adam(G1.parameters(), lr = args.G_lr)
G2_optimizer = torch.optim.Adam(G2.parameters(), lr = args.G_lr)
Dsource_optimizer = torch.optim.Adam(Dsource.parameters(), lr = args.D_lr)
Dtarget_optimizer = torch.optim.Adam(Dtarget.parameters(), lr = args.D_lr)

episode = 1
step = 1

# Start a fresh results file containing only the header row.
with open(log_folder + 'result.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['steps','reward','L_gan','L_3', 'L_G'])
print("start experiment")
# Per-update losses accumulated between two evaluation points.
L_gan_all = []
L_3_all = []
L_G_all = []
max_test_reward = -100000000.0

# ---------------------------------------------------------------------------
# Training loop helpers. They read the module-level networks, optimizers,
# buffers and ``args`` created above.
# ---------------------------------------------------------------------------

def _as_batch(rows, width):
    """Stack per-transition arrays into a float tensor viewed as (-1, width) on args.device."""
    return torch.tensor(np.array(rows), dtype=torch.float).view(-1, width).to(args.device)


def _sample_target_batch():
    """Sample target_batch_size transitions; return (state, action, next_state) float tensors."""
    sample = random.sample(target_buffer, args.target_batch_size)
    state, action, _, next_state, _ = zip(*sample)
    return (
        _as_batch(state, args.target_state_dim),
        # Actions are continuous: keep them as floats (casting to long truncates them).
        _as_batch(action, args.target_action_dim),
        _as_batch(next_state, args.target_state_dim),
    )


def _source_policy_action(source_state):
    """Query the pretrained source policy; returns a (-1, source_action_dim) tensor without gradient."""
    action, _ = source_model.predict(source_state.detach().cpu().numpy())
    return torch.tensor(action, dtype=torch.float).view(-1, args.source_action_dim).to(args.device)


def _policy_action(observation):
    """Map a vec-env observation to a target action via G1 -> source policy -> G2.

    Runs under ``torch.no_grad()`` because the result is only used to step
    the environment. The caller controls the train/eval mode of G1 and G2.
    """
    with torch.no_grad():
        target_state = torch.tensor(observation[0], dtype=torch.float).view(-1, 1, 1, args.target_state_dim).to(args.device)
        aligned_state = G1(target_state)
        source_action = _source_policy_action(aligned_state)
        target_action = G2(aligned_state.view(-1, args.source_state_dim), source_action)
    return target_action.cpu().numpy()


def _evaluate_policy(n_episodes=5):
    """Run ``n_episodes`` on eval_vec_env with G1/G2 in eval mode; return the mean episode return."""
    G1.eval()
    G2.eval()
    episode_returns = []
    for _ in range(n_episodes):
        obs_test = eval_vec_env.reset()
        dones_test = False
        total_reward_test = 0.0
        while not dones_test:
            action_test = _policy_action(obs_test)
            obs_test, rewards_test, dones_test, _ = eval_vec_env.step(action_test)
            total_reward_test += rewards_test
        episode_returns.append(total_reward_test)
    G1.train()
    G2.train()
    return np.mean(episode_returns)


def _train_env_model_step():
    """One L1-regression step of env_predict on a freshly sampled target batch."""
    state, action, next_state = _sample_target_batch()
    loss = nn.L1Loss()(env_predict(state, action), next_state)
    env_predict_optimizer.zero_grad()
    loss.backward()
    env_predict_optimizer.step()


def _labels(value, batch_size):
    """Return a (batch_size, 1) tensor filled with ``value`` on args.device."""
    return torch.full((batch_size, 1), float(value)).to(args.device)


def _update_state_alignment():
    """One G1 step followed by one Dsource step.

    Returns:
        (L_gan, L3, L_G1) as Python floats.
    """
    state, _, next_state = _sample_target_batch()
    source_sample = random.sample(source_buffer, args.source_batch_size)
    source_state, *_ = zip(*source_sample)
    source_state = _as_batch(source_state, args.source_state_dim)

    BCEloss = nn.BCELoss()
    loss_fn = nn.L1Loss()

    # Generator step: fool Dsource and keep aligned dynamics consistent (L3).
    G1_output = G1(state)
    action_output = _source_policy_action(G1_output)
    L3 = loss_fn(G1(env_predict(state, G2(G1(state), action_output))), G1(next_state))
    L_G1 = BCEloss(Dsource(G1_output), _labels(1, args.target_batch_size)) + 3 * L3
    G1_optimizer.zero_grad()
    L_G1.backward()
    G1_optimizer.step()

    # Discriminator step: aligned target states are "fake", source states "real".
    with torch.no_grad():
        G1_output = G1(state)
    L_gan = (BCEloss(Dsource(G1_output), _labels(0, args.target_batch_size))
             + BCEloss(Dsource(source_state), _labels(1, args.source_batch_size)))
    Dsource_optimizer.zero_grad()
    L_gan.backward()
    Dsource_optimizer.step()
    return L_gan.item(), L3.item(), L_G1.item()


def _update_action_alignment():
    """One G2 step followed by one Dtarget step.

    Only target-domain data is needed here, so no source batch is sampled.

    Returns:
        (L_gan, L3, L_G2) as Python floats.
    """
    state, action, next_state = _sample_target_batch()

    BCEloss = nn.BCELoss()
    loss_fn = nn.L1Loss()

    # Generator step: fool Dtarget and keep aligned dynamics consistent (L3).
    G1_output = G1(state)
    action_output = _source_policy_action(G1_output)
    G2_output = G2(G1(state), action_output)
    L3 = loss_fn(G1(env_predict(state, G2(G1(state), action_output))), G1(next_state))
    L_G2 = BCEloss(Dtarget(G2_output), _labels(1, args.target_batch_size)) + 3 * L3
    G2_optimizer.zero_grad()
    L_G2.backward()
    G2_optimizer.step()

    # Discriminator step: generated actions are "fake", buffered actions "real".
    with torch.no_grad():
        G1_output = G1(state)
        action_output = _source_policy_action(G1_output)
        G2_output = G2(G1(state), action_output)
    L_gan = (BCEloss(Dtarget(G2_output), _labels(0, args.target_batch_size))
             + BCEloss(Dtarget(action), _labels(1, args.target_batch_size)))
    Dtarget_optimizer.zero_grad()
    L_gan.backward()
    Dtarget_optimizer.step()
    return L_gan.item(), L3.item(), L_G2.item()


# ---------------------------------------------------------------------------
# Main loop: collect one target transition per step, evaluate every 500
# steps, train the dynamics model, and alternate G1/G2 alignment updates in
# blocks of 7000 steps.
# ---------------------------------------------------------------------------
while step <= args.total_step:
    obs = vec_env.reset()
    dones = False
    total_reward = 0.0

    while not dones:
        # Act with G1 in eval mode, then switch it back for the updates below.
        G1.eval()
        action = _policy_action(obs)
        G1.train()
        obs_next, rewards, dones, info = vec_env.step(action)

        # DummyVecEnv resets automatically when an episode ends and returns
        # the first observation of the NEXT episode; the true final
        # observation is delivered in info["terminal_observation"].
        if dones[0]:
            stored_next = info[0].get("terminal_observation", obs_next[0])
        else:
            stored_next = obs_next[0]
        put_transition(target_buffer, obs[0], action, rewards, stored_next, dones)
        obs = obs_next
        total_reward += rewards

        if step % 500 == 0:
            mean_test_reward = _evaluate_policy()
            tmp = [step, mean_test_reward]
            if len(L_gan_all) != 0:
                tmp.append(np.mean(L_gan_all))
                tmp.append(np.mean(L_3_all))
                tmp.append(np.mean(L_G_all))
                print(f"step: {step}, reward: {mean_test_reward}, L_gan: {np.mean(L_gan_all)}, L_G: {np.mean(L_G_all)}")
            else:
                print(f"step: {step}, reward: {mean_test_reward}")
            with open(log_folder + 'result.csv', 'a', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(tmp)
            L_gan_all = []
            L_3_all = []
            L_G_all = []
            # Keep only the best-performing checkpoint: remember the new best
            # score so later, worse evaluations do not overwrite it.
            if mean_test_reward > max_test_reward:
                max_test_reward = mean_test_reward
                torch.save(G1.state_dict(), log_folder + 'G1.pth')
                torch.save(G2.state_dict(), log_folder + 'G2.pth')

        if len(target_buffer) == args.target_batch_size:
            # One-off warm-up of the dynamics model the moment a full batch exists.
            for _ in tqdm(range(90000)):
                _train_env_model_step()
        elif len(target_buffer) > args.target_batch_size:
            _train_env_model_step()

        if len(target_buffer) > args.target_batch_size:
            # Alternate between state (G1) and action (G2) alignment every 7000 steps.
            if np.ceil(step / 7000) % 2 == 0:
                l_gan, l_3, l_g = _update_state_alignment()
            else:
                l_gan, l_3, l_g = _update_action_alignment()
            L_gan_all.append(l_gan)
            L_3_all.append(l_3)
            L_G_all.append(l_g)

        step += 1
        if step > args.total_step:
            break
    episode += 1
