import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

import os
import sys
import cv2
import yaml
import random
import socket
import pyflex
import numpy as np
import matplotlib.pyplot as plt

from copy import deepcopy
from collections import OrderedDict
from dataset import Experience, QNetDataset, TrainGoalsDataset, TestGoalsDataset
from models import QNet

import softgym.envs.tshirt_descriptor as td
from softgym.envs.corl_baseline import GCFold

from metrics import metrics

# Side length (in array elements) that pad() pads images up to and that unpad()
# assumes its input has when it computes the margin to crop.
IM_W = 566
# Unscaled image side length; unpad() multiplies it by the scale factor to
# work out how large the un-padded content is.
W = 200

def arg(tag, default):
    """Resolve a hyperparameter from the command line and record it in ``HYPERS``.

    If ``tag`` appears in ``sys.argv``, the token that follows it is converted
    to the type of ``default``; otherwise ``default`` itself is used. The
    resolved value is stored in the module-level ``HYPERS`` mapping under
    ``tag`` and returned.

    Args:
        tag: Exact command-line token that introduces the value.
        default: Fallback value; its type drives the conversion of the
            command-line string.

    Returns:
        The parsed command-line value, or ``default`` if ``tag`` is absent.

    Raises:
        IndexError: If ``tag`` is the last command-line token (no value follows).
    """
    if tag in sys.argv:
        value_pos = sys.argv.index(tag) + 1
        if value_pos >= len(sys.argv):
            raise IndexError(f"no value given after command-line tag '{tag}'")
        raw = sys.argv[value_pos]
        if isinstance(default, bool):
            # bool('False') is True, so boolean flags must be parsed by content.
            # Anything not recognised as false stays True, as it did before.
            value = raw.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')
        elif default is None:
            # NoneType cannot be called with an argument; keep the raw string.
            value = raw
        else:
            value = type(default)(raw)
    else:
        value = default
    HYPERS[tag] = value
    return value

def postprocess(heatmap, theta, scale):
    """Undo the rotate / pad / scale preprocessing on a network heatmap.

    Args:
        heatmap: Torch tensor produced by the network for one rotation/scale.
        theta: Rotation angle that was applied to the input; ``rotate`` is
            called with its negation.
        scale: Resize factor that was applied to the input.

    Returns:
        The heatmap as a NumPy array after ``rotate``, ``unpad`` and
        ``unscale_img`` have been applied in that order.
    """
    rotated_back = rotate(heatmap, -theta)
    # Move to host memory before the NumPy/OpenCV based steps.
    as_array = rotated_back.cpu().data.numpy()
    cropped = unpad(as_array, scale)
    return unscale_img(cropped, scale)

def rotate(img, theta):
    """Rotate an image tensor about its centre using nearest-neighbour sampling.

    The affine matrix is created on the same device as ``img`` so the function
    works for CPU tensors as well as CUDA tensors.

    Args:
        img: Torch tensor; a leading batch dimension is added before sampling,
            so the result of ``img.unsqueeze(0)`` must be a shape accepted by
            ``F.affine_grid`` / ``F.grid_sample``.
        theta: Angle in radians (it is fed to ``np.cos`` / ``np.sin``). The
            sign is flipped internally before the matrix is built.

    Returns:
        The resampled tensor with all singleton dimensions squeezed out.
    """
    theta = -theta
    batched = img.unsqueeze(0)
    cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))
    # 2x3 affine matrix (pure rotation, no translation), batch of one.
    affine_mat = torch.tensor(
        [[cos_t, sin_t, 0.0], [-sin_t, cos_t, 0.0]],
        dtype=torch.float32,
        device=img.device,
    ).unsqueeze(0)
    flow_grid = F.affine_grid(affine_mat, batched.size())
    sampled = F.grid_sample(batched, flow_grid, mode='nearest')
    return sampled.squeeze()

def scale_img(img, scale):
    """Resize ``img`` by the factor ``scale`` along both axes with ``cv2.resize``."""
    return cv2.resize(img, None, fx=scale, fy=scale)

def unscale_img(img, scale):
    """Resize ``img`` by ``1/scale`` along both axes, reversing ``scale_img``."""
    inverse = 1.0/scale
    return cv2.resize(img, None, fx=inverse, fy=inverse)

def pad(img):
    """Zero-pad the first two axes of ``img`` symmetrically towards ``IM_W``.

    The margin is ``ceil((IM_W - img.shape[0]) / 2)`` on every side of both
    leading axes, so an odd difference yields a result one element larger than
    ``IM_W``. A third axis, if present, is left unpadded.

    Args:
        img: Array with 2 or 3 dimensions.

    Returns:
        A new zero-padded array.
    """
    margin = int(np.ceil((IM_W - img.shape[0]) / 2))
    widths = [(margin, margin)] * 2
    if img.ndim == 3:
        # Leave the trailing (channel) axis untouched.
        widths = widths + [(0, 0)]
    return np.pad(img, widths, 'constant', constant_values=0)

def unpad(img, scale):
    """Crop away the symmetric margin that padding added around a scaled image.

    The margin is ``ceil((IM_W - ceil(W * scale)) / 2)`` on each side of the
    first two axes.

    Args:
        img: Padded array (indexed on its first two axes).
        scale: Resize factor that was applied before padding.

    Returns:
        An independent copy of the cropped region. If the computed margin is
        zero or negative (the scaled content already fills the padded size),
        a copy of the whole input is returned.
    """
    pad_amount = int(np.ceil((IM_W - np.ceil(W*scale))/2))
    if pad_amount <= 0:
        # img[0:-0] is an empty slice, and a negative margin would slice
        # nonsensically; there is nothing to crop in either case.
        return deepcopy(img)
    return deepcopy(img[pad_amount:-pad_amount, pad_amount:-pad_amount])

class Trainer():
    """Train a goal-conditioned Q-network that predicts per-pixel value heatmaps.

    The network (``QNet``) receives an observation image and a goal image
    stacked along the channel axis (6 channels). One heatmap is produced per
    (rotation, scale) pair; an action is ``[rotation_idx, scale_idx, row, col]``.
    A slowly-updated copy of the network (``tnet``) supplies the bootstrapped
    target value.
    """

    def __init__(self, cfg, model_name, env, data, test_data):
        """Store the configuration, create the output directories and the model.

        Args:
            cfg: Configuration mapping (keys read here: ``W``, ``num_rotations``,
                ``num_scales``, ``resize_scales``, a per-hostname entry with
                ``out_path``, plus the keys used by the other methods).
            model_name: Name of the run; used as the output sub-directory.
            env: Environment used by ``evaluate``.
            data: Training buffer; must support ``init_buffer()``, ``len()``
                and integer indexing.
            test_data: Evaluation set indexed by ``evaluate`` (may be ``None``
                if ``evaluate`` is never called).
        """
        self.cfg = cfg
        self.model_name = model_name
        self.env = env
        self.data = data
        self.test_data = test_data

        # Shorten hyperparameters
        self.W = self.cfg['W']
        self.num_rotations = self.cfg['num_rotations']
        self.num_scales = self.cfg['num_scales']
        self.resize_scales = self.cfg['resize_scales']
        self.rot_bin = 2*np.pi/self.num_rotations
        # Network input side: the image diagonal at the first resize scale,
        # presumably so that rotated inputs are not clipped.
        diag_len = float(self.W)*np.sqrt(2)*self.resize_scales[0][0]
        pad_px = int(np.ceil((diag_len - self.W)/2))
        self.im_w = self.W + 2*pad_px

        # Hostname specific vars
        hostname = socket.gethostname()
        if 'compute' in hostname:
            hostname = 'seuss'
        self.out_path = self.cfg[hostname]['out_path']

        self.init_dirs()
        self.init_model()

    def init_dirs(self):
        """Create the run's output directory tree and save the config into it.

        In debug mode an existing run directory is deleted first. Otherwise an
        existing run directory raises ``FileExistsError``.
        """
        run_dir = f'{self.out_path}/output/{self.model_name}'
        if self.cfg['debug'] and os.path.exists(run_dir):
            import shutil
            shutil.rmtree(run_dir)
        # makedirs also creates a missing '<out_path>/output' parent, but still
        # fails if the run directory itself already exists.
        os.makedirs(run_dir)
        for sub_dir in ('images', 'evals', 'weights'):
            os.mkdir(f'{run_dir}/{sub_dir}')
        with open(f'{run_dir}/config.yaml', 'w') as f:
            # Dump the config this trainer was given, not a module-level global.
            yaml.dump(self.cfg, f)

    def init_model(self):
        """Build the online network, its target copy and the Adam optimizer."""
        in_channels = 6  # 3 observation channels + 3 goal channels
        self.qnet = QNet(in_channels, self.im_w).cuda()
        self.tnet = deepcopy(self.qnet)
        self.opt = optim.Adam(self.qnet.parameters(), lr=self.cfg['learn_rate'], weight_decay=self.cfg['weight_decay'])

    def run(self):
        """Run the training loop for ``cfg['episodes'] + 1`` update steps.

        After every step the target network is blended towards the online
        network with factor ``cfg['tau']``. Every ``cfg['eval_interval']``
        steps (if positive) the loss plot, both networks, the optimizer and
        the loss history are written to disk.
        """
        self.data.init_buffer()

        loss_list = []
        tau = self.cfg['tau']
        for step in range(self.cfg['episodes']+1):
            loss = self.train(step)
            print("loss: {}".format(loss))
            loss_list.append(loss)

            # update target network
            for o, t in zip(self.qnet.parameters(), self.tnet.parameters()):
                t.data = o.data.clone()*tau + t.data.clone()*(1-tau)

            if self.cfg['eval_interval'] > 0 and step % self.cfg['eval_interval'] == 0:
                self.plot_loss(loss_list)

                weights_dir = '{}/output/{}/weights'.format(self.out_path, self.model_name)
                suffix = '{}_rotations_{}_gamma_{}_update_{}.pt'.format(self.cfg['run_name'], self.num_rotations, self.cfg['gamma'], step)
                torch.save(self.qnet, '{}/qnet_{}'.format(weights_dir, suffix))
                torch.save(self.tnet, '{}/tnet_{}'.format(weights_dir, suffix))
                torch.save(self.opt, '{}/opt_{}'.format(weights_dir, suffix))
                np.save("{}/output/{}/loss.npy".format(self.out_path, self.model_name), loss_list)

    def train(self, step):
        """Perform one optimisation step on a randomly sampled batch.

        For each sampled transition the target is
        ``rew + gamma * max(target-network heatmaps of nobs) * (not done)`` and
        the prediction is the online network's heatmap value at the action's
        pixel. The smooth-L1 losses are averaged over the batch.

        Args:
            step: Index of the current update; used for visualisation cadence
                and file names.

        Returns:
            The mean batch loss as a Python float.
        """
        self.qnet.train()
        batch_idxs = random.sample(range(len(self.data)), self.cfg['batch'])

        # run() treats eval_interval <= 0 as "disabled"; apply the same guard
        # here so the modulo can never divide by zero.
        eval_interval = self.cfg['eval_interval']
        visualize = eval_interval > 0 and step % eval_interval == 0

        losses = []
        for i, b_idx in enumerate(batch_idxs):
            b = self.data[b_idx]
            obs = b['obs']
            nobs = b['nobs']
            goal = b['goal']

            ## EVALUATE Qn
            heatmaps, angles_scale_idxs = self.get_heatmaps(nobs/255.0, goal/255.0, target=True)
            maps = torch.zeros([self.num_rotations*self.num_scales, self.W, self.W],
                                dtype=torch.float32, device='cuda')
            vmaps = []
            # NOTE(review): the mask is computed from `obs` although it is applied
            # to heatmaps of `nobs` -- confirm this is intended.
            obs_mask = (obs[:,:,0] > 0) | (obs[:,:,1] > 0) | (obs[:,:,2] > 0)
            for m_idx, (heatmap, angle_scale_idx) in enumerate(zip(heatmaps, angles_scale_idxs)):
                scale = self.resize_scales[angle_scale_idx[1]][0]

                size = (int(nobs.shape[0]*scale), int(nobs.shape[1]*scale))
                padding = int(np.ceil((self.im_w - size[0])/2))
                # Undo the pad + resize that get_heatmaps applied to the input.
                transform = torch.nn.Sequential(
                    transforms.CenterCrop(self.im_w - padding*2),
                    transforms.Resize((nobs.shape[0], nobs.shape[1]), interpolation=2))

                theta = np.degrees(angle_scale_idx[0]*self.rot_bin)
                m = transform(TF.rotate(heatmap, -theta))
                vmaps.append(deepcopy(m).detach().squeeze().cpu().numpy())
                maps[m_idx, :, :] = m
                # Exclude masked-out pixels from the max below.
                maps[m_idx, ~obs_mask] = torch.finfo(torch.float32).min

            Qn = torch.max(maps)

            # Calculate target value
            y = b['rew'] + self.cfg['gamma']*Qn*(not b['done'])

            # Get Q from obs, act
            theta = np.degrees(b['act'][0]*self.rot_bin)
            scale = self.resize_scales[b['act'][1]][0]

            obs_t = torch.from_numpy(obs/255.0).float().cuda().permute(2, 0, 1)
            goal_t = torch.from_numpy(goal/255.0).float().cuda().permute(2, 0, 1)
            size = (int(obs.shape[0]*scale), int(obs.shape[1]*scale))
            padding = int(np.ceil((self.im_w - size[0])/2))
            transform = nn.Sequential(
                transforms.Resize(size, interpolation=2),
                transforms.Pad(padding, fill=0, padding_mode='constant'))
            obs_in = TF.rotate(transform(obs_t), theta).unsqueeze(0)
            goal_in = TF.rotate(transform(goal_t), theta).unsqueeze(0)
            x = torch.cat([obs_in, goal_in], dim=1)
            heatmap_unrot = self.qnet(x)
            heatmap = TF.rotate(heatmap_unrot, -theta).squeeze()

            # Push a one-hot mask of the action pixel through the same
            # resize + pad to find where that pixel lands in the heatmap.
            mask = torch.zeros((self.W, self.W), dtype=torch.uint8, device='cuda')
            mask[b['act'][2], b['act'][3]] = 255
            mask_rot = transform(mask.unsqueeze(0)).squeeze()
            mask_rot[mask_rot != 0] = 255

            ind = [torch.floor(torch.argmax(mask_rot) / mask_rot.size()[1]).long(),
                   torch.argmax(mask_rot) % mask_rot.size()[1]]

            q = heatmap[ind[0],ind[1]]

            y = torch.FloatTensor([y]).cuda()
            loss = F.smooth_l1_loss(q, y)
            losses.append(loss)

            # Visualize training
            if visualize and i < 5:
                fig, ax = plt.subplots(1, 5, figsize=(16, 8), dpi=100)
                ax[0].set_title("obs")
                viz_obs = self.draw_arrow(obs, b['act'])
                ax[0].imshow(viz_obs)
                ax[1].set_title("nobs")
                ax[1].imshow(nobs)
                ax[2].set_title("goal")
                ax[2].imshow(goal)
                ax[3].set_title(f"q: {q.detach().cpu().numpy():.3f} y: {y[0].detach().cpu().numpy():.3f}")
                ax[3].imshow(heatmap.detach().cpu().numpy())
                ax[3].scatter(ind[1].detach().cpu().numpy(), ind[0].detach().cpu().numpy(), alpha=0.5, c='w')
                ax[4].set_title(f"nobs max: {Qn.detach().cpu().numpy():.3f}\nmin: {np.min(vmaps):.3f}")
                viz_img = self.render_heatmaps(nobs, vmaps)
                viz_img = cv2.cvtColor(viz_img, cv2.COLOR_BGR2RGB)
                ax[4].imshow(viz_img)
                plt.tight_layout()
                # plt.show()
                plt.savefig(f'{self.out_path}/output/{self.model_name}/images/{step}_{i}.png')
                plt.close()

        loss = torch.mean(torch.stack(losses))

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        return loss.item()

    def get_heatmaps(self, img, goal, target=False):
        """Run the network on every (rotation, scale) variant of ``img``/``goal``.

        Each variant is resized by the scale factor, zero-padded to
        ``self.im_w`` and rotated by ``rotate_idx * rot_bin``. No gradients are
        recorded.

        Args:
            img: Observation as an H x W x 3 NumPy array (callers in this class
                pass values already divided by 255).
            goal: Goal image with the same layout as ``img``.
            target: If true, use the target network instead of the online one.

        Returns:
            ``(heatmaps, angles_scale_idxs)`` where ``heatmaps`` is the network
            output for the whole batch and ``angles_scale_idxs[k]`` is the
            ``(rotate_idx, scale_idx)`` pair of batch element ``k``.
        """
        with torch.no_grad():
            angles_scale_idxs = []
            batch_shape = [self.num_rotations*self.num_scales, 3, self.im_w, self.im_w]
            imgs = torch.zeros(batch_shape, dtype=torch.float32, device='cuda')
            goals = torch.zeros(batch_shape, dtype=torch.float32, device='cuda')
            img_t = torch.from_numpy(img).float().cuda().permute(2, 0, 1)
            goal_t = torch.from_numpy(goal).float().cuda().permute(2, 0, 1)

            for rotate_idx in range(self.num_rotations):
                angle_deg = np.degrees(rotate_idx*self.rot_bin)
                for scale_idx in range(self.num_scales):
                    scale = self.resize_scales[scale_idx][0]
                    i = rotate_idx*self.num_scales + scale_idx

                    size = (int(img.shape[0]*scale), int(img.shape[1]*scale))
                    padding = int(np.ceil((self.im_w - size[0])/2))
                    transform = nn.Sequential(
                        transforms.Resize(size, interpolation=2),
                        transforms.Pad(padding, fill=0, padding_mode='constant'))

                    imgs[i, :, :, :] = TF.rotate(transform(img_t), angle_deg)
                    goals[i, :, :, :] = TF.rotate(transform(goal_t), angle_deg)

                    angles_scale_idxs.append((rotate_idx, scale_idx))

            x = torch.cat([imgs, goals], dim=1)
            net = self.tnet if target else self.qnet
            heatmaps = net(x)
            return heatmaps, angles_scale_idxs

    def select_action(self, obs, goal, rgb, viz=True, random=False):
        """Pick an action for ``obs`` either greedily or at random.

        Args:
            obs: H x W x 3 observation; pixels that are zero in all three
                channels are excluded from selection.
            goal: Goal image with the same layout.
            rgb: Colour image on which the chosen action is drawn.
            viz: If true, also render the heatmap grid.
            random: If true, sample a random rotation, scale and masked pixel
                instead of taking the arg-max of the heatmaps.

        Returns:
            ``(action, disp_img, viz_img)``; ``action`` is
            ``[rotation_idx, scale_idx, row, col]`` and ``viz_img`` is ``None``
            when ``viz`` is false.
        """
        mask = (obs[:,:,0] > 0) | (obs[:,:,1] > 0) | (obs[:,:,2] > 0)
        with torch.no_grad():
            self.qnet.eval()
            heatmaps, angles_scale_idxs = self.get_heatmaps(obs/255.0, goal/255.0)
            maps = []
            vmaps = []
            for heatmap, angle_scale_idx in zip(heatmaps, angles_scale_idxs):
                scale = self.resize_scales[angle_scale_idx[1]][0]
                m = postprocess(heatmap, angle_scale_idx[0]*self.rot_bin, scale)
                vmaps.append(deepcopy(m))
                # Make masked-out pixels unselectable by the arg-max.
                m[~mask] = -10000000
                maps.append(m)
            maps = np.array(maps)

            if random:
                uv = td.random_sample_from_masked_image(mask, 1)
                u,v = uv[0][0], uv[1][0] # TODO check this
                action = [np.random.randint(self.num_rotations), np.random.randint(self.num_scales), u, v]
            else:
                act = list(np.unravel_index(np.argmax(maps),maps.shape))
                # The first axis is flattened as rotate_idx*num_scales + scale_idx.
                rotate_max_idx = int(np.floor(act[0]/self.num_scales))
                scale_max_idx  = act[0]%self.num_scales
                action = [rotate_max_idx, scale_max_idx, act[1], act[2]]

            # Always bind viz_img so the return below cannot hit an unbound local.
            viz_img = self.render_heatmaps(obs, vmaps, action) if viz else None

            disp_img = self.draw_arrow(rgb.astype(np.uint8), action)
        return action, disp_img, viz_img

    def draw_arrow(self, obs, action):
        """Return a copy of ``obs`` with the action drawn as a circle and arrow.

        Args:
            obs: Image to draw on (not modified).
            action: ``[rotation_idx, scale_idx, row, col]``.

        Returns:
            The annotated copy.
        """
        viz_obs = deepcopy(obs)
        theta = np.deg2rad((360.0/self.num_rotations) * action[0])
        fold_dist = self.cfg['mid_fold_dist'] * self.resize_scales[action[1]][1]
        # Plain ints: the indices may be NumPy integers, which some OpenCV
        # builds reject as point coordinates.
        x = int(action[3])
        y = int(action[2])

        arrow_x = int(x + fold_dist*np.cos(theta))
        arrow_y = int(y + fold_dist*np.sin(theta))

        color = (0,0,200)
        cv2.circle(viz_obs, (x, y), 6, color, 2)
        cv2.arrowedLine(viz_obs, (x, y), (arrow_x, arrow_y), color, 2, tipLength=0.1)
        return viz_obs

    def render_heatmaps(self, color, heatmaps, max_idx=(-1, -1)):
        """Tile all heatmaps into one image, blended with ``color``.

        Heatmaps are normalised jointly (range widened to include 0 and 1.5),
        colour-mapped, blended 50/50 with ``color*255``, annotated with an
        arrow for the rotation/scale, and resized to 100x100. Scales are
        stacked vertically within a column; rotations form the columns.

        Args:
            color: Image blended under every heatmap.
                NOTE(review): it is multiplied by 255, which suggests a [0, 1]
                range, but callers in this class pass the raw observation --
                confirm the expected range.
            heatmaps: Sequence of ``num_rotations*num_scales`` 2-D arrays,
                ordered as ``rotate_idx*num_scales + scale_idx``.
            max_idx: Optional action ``[rotation_idx, scale_idx, row, col]`` to
                mark with a circle; the default marks nothing.

        Returns:
            The tiled image, or ``None`` if there are no rotations.
        """
        hi = np.max(heatmaps)
        lo = np.min(heatmaps)
        maxh = hi if hi > 1.5 else 1.5
        minh = lo if lo < 0 else 0
        heatmaps = (np.array(heatmaps) - minh) / (maxh - minh)

        columns = []
        i = 0
        for r in range(self.num_rotations):
            tiles = []
            for s in range(self.num_scales):
                m = heatmaps[i]
                m = cv2.applyColorMap((m*255).astype(np.uint8), cv2.COLORMAP_JET)
                m = (0.5*m + 0.5*(color*255)).astype(np.uint8)

                theta = r*self.rot_bin
                fold_dist = 0.1*((s)+1.0) # because of flip
                line_start = np.array([self.W/2, self.W/2])
                # Truncate with a full-width integer type: int8 wraps around
                # once the offset exceeds 127.
                offset = (self.W*(fold_dist*np.array([np.sin(theta), np.cos(theta)]))).astype(int)
                line_end = line_start + offset

                m = cv2.arrowedLine(m, (int(line_start[1]),int(line_start[0])), (int(line_end[1]),int(line_end[0])), (100,100,100), 3, tipLength=0.3)
                if r == max_idx[0] and s == max_idx[1]:
                    m = cv2.circle(m, (int(max_idx[3]), int(max_idx[2])), 7, (255,255,255), 2)
                m = cv2.resize(m, (100,100))
                tiles.append(m)
                i+=1
            columns.append(np.concatenate(tiles, axis=0))
        if not columns:
            return None
        return np.concatenate(columns, axis=1)

    def evaluate(self, step):
        """Evaluate the current weights on the goals

        For every entry of ``self.test_data`` the environment is reset to the
        stored start state and actions are selected greedily until the
        environment reports ``done``. Per-episode image strips are written to
        the ``evals`` directory and the best (minimum) distance with its step
        index is appended to ``evals/metrics.csv``.

        Args:
            step: Training step the weights belong to; used in file names and
                in the metrics row.
        """
        print(f"evaluating weight {step}")
        evals_dir = f'{self.out_path}/output/{self.model_name}/evals'
        for i in range(len(self.test_data)):
            start_pos, goal, goal_pos = self.test_data[i]

            # reset env
            obs = self.env.reset(given_goal=goal)
            pyflex.set_positions(start_pos)
            pyflex.step()
            obs = self.env._get_obs()
            done = False

            vizs = []
            disps = []
            distances = []
            while not done:
                act, disp_img, viz_img = self.select_action(obs['depth'], obs["goal"], rgb=obs["color"])

                # step env
                nobs, rew, done, _ = self.env.step(act, on_table=self.cfg['on_table'])
                nobs_mask = (nobs['depth'][:,:,0] > 0) | (nobs['depth'][:,:,1] > 0) | (nobs['depth'][:,:,2] > 0)
                nobs['depth'][~nobs_mask, :] = 0

                # Calc particle metrics
                pos = pyflex.get_positions().reshape(-1, 4)[:,:3]
                pos_metric = td.calc_metric_abs(pos, goal_pos)
                distances.append(pos_metric[1])

                vizs.append(viz_img)
                disps.append(disp_img)

                obs = deepcopy(nobs)

            disps.append(obs["color"]) # final obs

            start_name, goal_name = self.cfg['eval_combos'][i]
            name = f'{start_name}-{goal_name}'
            cv2.imwrite(f'{evals_dir}/{name}_disp_step_{step}.png', np.vstack(disps))
            cv2.imwrite(f'{evals_dir}/{name}_viz_step_{step}.png', np.vstack(vizs))

            # add the best metrics for episodes
            min_dist = min(distances)
            min_step = np.argmin(distances)
            with open(f'{evals_dir}/metrics.csv', 'a+') as f:
                f.write(f'{step},{name},{min_dist},{min_step}\n')

    def plot_loss(self, loss_list):
        """Plot the loss history and save it as ``loss.png`` in the run directory."""
        plt.figure(figsize=(12, 4))
        plt.title('average loss')
        plt.plot(range(len(loss_list)), loss_list)
        plt.savefig(f'{self.out_path}/output/{self.model_name}/loss.png')
        plt.close()

if __name__ == '__main__':
    # Registry that arg() fills with the resolved command-line hyperparameters.
    HYPERS = OrderedDict()
    ENV_TYPE = arg('env', 'towel') # towel # tshirt

    with open('config.yaml') as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # Seed every random number generator used by the script.
    seed = cfg['seed']
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Only the towel environment is implemented.
    if ENV_TYPE != 'towel':
        raise NotImplementedError
    env = GCFold(use_depth=True,
                 use_cached_states=False,
                 horizon=5,
                 action_repeat=1,
                 headless=cfg['headless'])

    # The run name encodes the main settings of the experiment.
    online_tag = "_online" if cfg["online"] else "_offline"
    model_name = ''.join([
        f'{ENV_TYPE}_{cfg["obs_mode"]}',
        f'_flow{cfg["flow"]}',
        f'_itr{cfg["episodes"]}',
        f'_lr{cfg["learn_rate"]}',
        f'_gamma{cfg["gamma"]}',
        f'_seed{cfg["seed"]}',
        online_tag,
        f'{cfg["model_suffix"]}',
    ])

    data = QNetDataset(cfg, model_name, env.camera_params)
    eval_data = None

    trainer = Trainer(cfg, model_name, env, data, eval_data)
    trainer.run()
