import numpy as np
import matplotlib.pyplot as plt
from models import FlowNetPickSplit
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from collections import namedtuple
from dataset import Goals
from torch.utils.data import DataLoader
import cv2
from softgym.envs.bimanual_env import BimanualEnv
from softgym.envs.bimanual_tshirt import BimanualTshirtEnv
import sys
sys.path.append('/home/exx/projects/softagent/descriptors_softgym_baseline')
from utils import remove_dups
import pyflex
import os
import time 
from copy import deepcopy
import softgym.envs.tshirt_descriptor as td
import argparse
from omegaconf import OmegaConf
from flow_utils import plot_flow


# Named record with fields obs, goal, act, rew, nobs, done
# (not referenced elsewhere in the visible code).
Experience = namedtuple('Experience', ['obs', 'goal', 'act', 'rew', 'nobs', 'done'])

def ccw(A, B, C):
    """Return True if points A, B, C turn counter-clockwise.

    Compares the two terms of the 2-D cross product (B - A) x (C - A) in a
    right-handed (x, y) frame; collinear points give False.
    """
    lhs = (C[1] - A[1]) * (B[0] - A[0])
    rhs = (B[1] - A[1]) * (C[0] - A[0])
    return lhs > rhs

def intersect(A, B, C, D):
    """Return True if segment AB crosses segment CD.

    Orientation test: the segments cross when A and B lie on opposite sides
    of CD *and* C and D lie on opposite sides of AB. Touching endpoints and
    collinear overlaps are not treated specially.
    """
    if ccw(A, C, D) == ccw(B, C, D):
        return False
    return ccw(A, B, C) != ccw(A, B, D)


class EnvRollout(object):
    """Evaluate a trained pick/place model by rolling it out in a SoftGym cloth env.

    ``cfgs`` is a dict of evaluation settings that is read throughout the class
    (seed, env options, model run/checkpoint, thresholds, output paths).
    """

    def __init__(self, cfgs):
        self.cfgs = cfgs

        seed = cfgs['seed']
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        use_depth = cfgs['img_type'] == 'depth'
        if cfgs['cloth_type'] == 'towel':
            self.env = BimanualEnv(use_depth=use_depth,
                                   use_cached_states=False,
                                   horizon=cfgs['horizon'],
                                   use_desc=False,
                                   action_repeat=1,
                                   headless=cfgs['headless'],
                                   shape=cfgs['shape'])
        elif cfgs['cloth_type'] == 'tshirt':
            self.env = BimanualTshirtEnv(use_depth=use_depth,
                                         use_cached_states=False,
                                         use_desc=False,
                                         horizon=cfgs['horizon'],
                                         action_repeat=1,
                                         headless=cfgs['headless'])
        else:
            # fail early instead of an AttributeError on self.env much later
            raise ValueError(f"unknown cloth_type: {cfgs['cloth_type']!r}")

        self.load_model()

        goal_data = Goals(self.cfgs, mode=cfgs['goals'])
        self.goal_loader = DataLoader(goal_data, batch_size=1, shuffle=False, num_workers=0)

        run_name = cfgs['run_name']
        loss_tag = "nominl" if "nominl" in run_name else "minl"
        self.save_dir = (f'{cfgs["output_dir"]}/{run_name}/rollout_{cfgs["goals"]}_{cfgs["input_mode"]}'
                         f'_{self.model.cfg.predictplace}_{loss_tag}'
                         f'_repeat{cfgs["repeat"]}_skip{cfgs["skip"]}')
        # makedirs(exist_ok=True) also creates missing parent dirs, and the
        # 'unscaled' subdir when save_dir already exists from an earlier run.
        os.makedirs(f'{self.save_dir}/unscaled', exist_ok=True)

        self.ups = nn.Upsample(size=(200, 200), mode='bilinear', align_corners=True)

    def _particle_uv_200(self):
        """Project all particles to the camera, swap the two uv columns and
        rescale from the [0, 719] camera frame to the [0, 199] network frame."""
        uv = td.particle_uv_pos(self.env.camera_params, None)
        uv[:, [1, 0]] = uv[:, [0, 1]]
        return (uv / 719) * 199

    def get_corner_particles(self):
        """Reset the env and return the indices of the particles nearest to
        four fixed anchor points near the corners of the 200x200 frame."""
        self.env.reset(given_goal=None, given_goal_pos=None)
        uv = self._particle_uv_200()
        anchors = ([25, 25], [25, 175], [175, 175], [175, 25])
        return tuple(np.linalg.norm(uv - np.array(a), axis=1).argmin() for a in anchors)

    def get_particle_uv(self, idx):
        """Return the rescaled (u, v) image coordinates of particle ``idx``."""
        u, v = self._particle_uv_200()[idx]
        return u, v

    def get_true_corner_mask(self, clothmask, r=4):
        """Return a 200x200 mask with discs of radius ``r`` around each corner
        particle in ``self.corners``, restricted to ``clothmask``.

        NOTE(review): ``self.corners`` is not assigned in this class; it must be
        set by the caller (e.g. from ``get_corner_particles``) before use.
        """
        true_corners = np.zeros((200, 200))
        h, w = true_corners.shape
        for c in self.corners:
            b, a = self.get_particle_uv(c)
            y, x = np.ogrid[-a:h - a, -b:w - b]
            true_corners[x * x + y * y <= r * r] = 1

        return true_corners * clothmask

    def load_model(self):
        """Load checkpoint ``epoch={cfgs['load_iter']}`` of run ``cfgs['run_name']``.

        Sets ``self.model`` to a ``FlowNetPickSplit`` on the GPU in eval mode.
        """
        run_name = self.cfgs['run_name']
        load_iter = self.cfgs['load_iter']
        run_dir = f"{self.cfgs['output_dir']}/{run_name}"

        ckpt_path = f"{run_dir}/default_default/0_0/checkpoints/epoch={load_iter}.ckpt"
        cfg_path = f"{run_dir}/.hydra/config.yaml"
        with open(cfg_path, 'r') as f:
            cfg = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader))

        # torch.load unpickles the checkpoint, which can run arbitrary code:
        # only point this at trusted training outputs.
        ckpt = torch.load(ckpt_path)
        self.model = FlowNetPickSplit(cfg.netconf).cuda()
        self.model.load_state_dict(ckpt['state_dict'])
        self.model.eval()

    def get_obs_tensor(self):
        """Return the current depth image resized to 200x200 as a (1, 1, 200, 200) float tensor."""
        _, depth = self.env.get_rgbd()
        depth = cv2.resize(depth, (200, 200))
        if self.cfgs['shape'] == 'large':
            # fixed offset applied to nonzero (cloth) depth values for the 'large' shape
            depth[depth > 0.] -= 0.325
        return torch.FloatTensor(depth).unsqueeze(0).unsqueeze(0)

    def get_rgbd(self):
        """Render the pyflex sensor and return ``(image, depth, mask)``.

        ``image`` comes from ``env.get_image``; ``depth`` is channel 3 of the
        vertically flipped sensor buffer and ``mask`` is ``depth > 0``.
        """
        rgbd = np.array(pyflex.render_sensor()).reshape(
            self.env.camera_height, self.env.camera_width, 4)[::-1, :, :]
        img = self.env.get_image(self.env.camera_height, self.env.camera_width)
        depth = rgbd[:, :, 3]
        mask = depth > 0
        return img, depth, mask

    def calc_gt_flow(self, uv_n_f):
        """Build a ground-truth flow tensor (1, C, H, W) from the current particle
        projections to the target projections ``uv_n_f``.

        NOTE(review): ``self.flw`` is not assigned in this class; it must be
        provided externally before calling this method.
        """
        _, depth_o = self.env.get_rgbd()
        coords_o = pyflex.get_positions().reshape(-1, 4)
        uv_o_f = td.particle_uv_pos(self.env.camera_params, None)

        # Remove occluded particles
        uv_o, _ = remove_dups(self.env.camera_params, uv_o_f, coords_o, depth_o, zthresh=0.001)

        # Remove out-of-bounds projections
        uv_o[uv_o < 0] = float('NaN')
        uv_o[uv_o >= 720] = float('NaN')

        flow = self.flw.get_image(uv_o, uv_n_f)
        flow = flow.transpose([2, 0, 1])
        return torch.FloatTensor(flow).unsqueeze(0)

    def run(self):
        """Main eval loop: roll out the model on every goal sequence.

        For each batch of ``self.goal_loader`` the env is reset to the batch's
        start particle positions; for each subgoal the policy is queried and
        executed ``cfgs['repeat']`` times. Images and depth/flow dumps go to
        ``self.save_dir``; the action list is saved to ``actions.npy``.

        Returns:
            (total_metrics, normed_metrics): per-goal mean particle distance to
            the final goal, and that distance divided by the starting distance.
        """
        actions = []
        total_metrics = []
        normed_metrics = []
        times = []

        for i, (start, goals, coords_o, coords_n) in enumerate(self.goal_loader):
            self.env.reset(given_goal=None, given_goal_pos=None)
            pyflex.set_positions(coords_o[0])
            pyflex.step()

            for j in range(goals.shape[1]):
                goal = goals[0, j].unsqueeze(0)

                for k in range(self.cfgs['repeat']):
                    self.env.render(mode='rgb_array')
                    obs = self.get_obs_tensor()

                    st = time.time()
                    action, unmasked_pred, info = self.get_action(obs, goal)
                    times.append(time.time() - st)

                    actions.append(action)
                    next_state, _, _, _ = self.env.step(action, pickplace=True, on_table=self.cfgs['on_table'])
                    self.env.render(mode='rgb_array')

                    img = cv2.cvtColor(next_state["color"], cv2.COLOR_RGB2BGR)
                    img = self.action_viz(img, action, unmasked_pred)

                    if self.cfgs['unfold']:
                        cv2.imwrite(f'{self.save_dir}/{i}_{j}_unfold.png', img)
                    else:
                        cv2.imwrite(f'{self.save_dir}/{i}_{j}_{k}_fold.png', img)

                    # debug dumps of the resulting depth (and predicted flow if any)
                    _, unsdepth = self.env.get_rgbd()
                    np.save(f'{self.save_dir}/unscaled/{i}_{j}_{k}_nobs.npy', unsdepth)
                    if 'flow' in info:
                        np.save(f'{self.save_dir}/unscaled/{i}_{j}_{k}_flow.npy', info['flow'])

            metrics, normed = self.calc_metric(start=coords_o[0, :, :3], goal=coords_n[-1][0, :, :3])
            total_metrics.append(metrics)
            normed_metrics.append(normed)
            print(f"goal {i}: {metrics},  {normed} (norm)")

        print("configs \n")
        print(self.cfgs)

        print(f"average action time: {np.mean(times)}")

        print("\nmean, std metrics: ", np.mean(total_metrics), np.std(total_metrics))
        print("normalized metrics: ", np.mean(normed_metrics))

        np.save(f"{self.save_dir}/actions.npy", actions)
        return total_metrics, normed_metrics

    def get_gaussian(self, u, v, sigma=5, size=None):
        """Return a min-max normalised 2-D Gaussian heatmap centred at (u, v).

        Args:
            u: centre along dim 2 of the output.
            v: centre along dim 3 of the output.
            sigma: standard deviation in pixels.
            size: side length of the square map; defaults to ``cfgs["im_width"]``.

        Returns:
            CUDA tensor of shape (1, 1, size, size) scaled to [0, 1].
        """
        if size is None:
            size = self.cfgs["im_width"]

        x0, y0 = torch.Tensor([u]).cuda(), torch.Tensor([v]).cuda()
        x0 = x0[:, None]
        y0 = y0[:, None]

        N = 1
        num = torch.arange(size).float()
        x, y = torch.vstack([num] * N).cuda(), torch.vstack([num] * N).cuda()
        gx = torch.exp(-(x - x0) ** 2 / (2 * sigma ** 2))
        gy = torch.exp(-(y - y0) ** 2 / (2 * sigma ** 2))
        g = torch.einsum('ni,no->nio', gx, gy)

        gmin = g.amin(dim=(1, 2))
        gmax = g.amax(dim=(1, 2))
        g = (g - gmin[:, None, None]) / (gmax[:, None, None] - gmin[:, None, None])
        return g.unsqueeze(1)

    def get_pt(self, logits, min_r=3):
        """Return the (u, v) image location of the most probable cell in channel 0.

        The flat argmax is decoded with the logit grid width and scaled to the
        ``cfgs['im_width']`` image (for a 20x20 grid and 200 px this is x10).
        ``min_r`` is currently unused.
        """
        N = logits.size(0)
        H, W = logits.shape[-2:]

        if self.cfgs['prob_type'] == 'sigmoid':
            probs = torch.sigmoid(logits)
            probs = probs.view(N, 1, H * W)
        else:
            probs = F.softmax(logits.flatten(-2), -1)

        _, idx = torch.max(probs[:, 0], 1)
        scale = self.cfgs['im_width'] // W
        u = (idx // W) * scale
        v = (idx % W) * scale

        return u.item(), v.item()

    def nearest_to_mask(self, u, v, depth):
        """Return the (row, col) index of the nonzero ``depth`` pixel closest to (u, v)."""
        mask_idx = np.argwhere(depth)
        return mask_idx[((mask_idx - [u, v]) ** 2).sum(1).argmin()]

    def get_action(self, obs, nobs):
        """Run inference and convert the model output into a bimanual action.

        Args:
            obs: current observation tensor (from ``get_obs_tensor``).
            nobs: goal observation tensor, concatenated with ``obs`` on dim 1.

        Returns:
            action: array ``[pick1, place1, pick2, place2]`` of (u, v) points;
                an arm suppressed by a threshold gets ``[0, 0]`` for both points.
            unmasked: array ``[pred_1, pred_2]`` of raw pick predictions
                before snapping them onto the cloth mask.
            info: dict; holds ``'flow'`` (numpy) when place points come from flow.

        Raises:
            NotImplementedError: if ``cfgs['input_mode']`` is not ``'noflow'``.
        """
        input_mode = self.cfgs['input_mode']
        if input_mode != 'noflow':
            # 'obsflow' / 'flowonly' need a flow image as network input, but no
            # flow is available before inference here.
            raise NotImplementedError(
                f"input_mode {input_mode!r} is not supported; only 'noflow' is implemented")

        obs = obs.cuda()
        nobs = nobs.cuda()
        x1 = torch.cat([obs, nobs], dim=1)

        predict_place = self.model.cfg[self.model.cfg.nettype]['predictplace']

        # evaluation only: no autograd graph is needed
        with torch.no_grad():
            out = self.model(x1)
        if predict_place:
            (flow, [pick_u1, pick_v1], [pick_u2, pick_v2], logits1, logits2,
             [place_u1, place_v1], [place_u2, place_v2], _, _, _) = out
            place_u1 = place_u1.detach().squeeze().cpu().numpy()
            place_v1 = place_v1.detach().squeeze().cpu().numpy()
            place_u2 = place_u2.detach().squeeze().cpu().numpy()
            place_v2 = place_v2.detach().squeeze().cpu().numpy()
        else:
            flow, [pick_u1, pick_v1], [pick_u2, pick_v2], logits1, logits2, _ = out

        # snap raw pick predictions onto the nearest nonzero-depth (cloth) pixel
        depth_arr = obs.detach().cpu().numpy()[0, 0]
        pick1 = self.nearest_to_mask(pick_u1, pick_v1, depth_arr)
        pick2 = self.nearest_to_mask(pick_u2, pick_v2, depth_arr)

        if predict_place:
            place1 = np.array([place_u1, place_v1])
            place2 = np.array([place_u2, place_v2])
            # if the two pick->place segments cross, swap the place points
            if intersect(pick1, place1, pick2, place2):
                place1, place2 = place2, place1
        else:
            flow_arr = flow.detach().cpu().numpy()[0]
            place1 = np.array(self.get_flow_place_pt(*pick1, flow_arr))
            place2 = np.array(self.get_flow_place_pt(*pick2, flow_arr))

        pred_1 = np.array([pick_u1, pick_v1])
        pred_2 = np.array([pick_u2, pick_v2])

        s_thres = self.cfgs['s_pick_thres']
        a_thres = self.cfgs['a_len_thres']

        # picks closer than s_pick_thres: keep only the first arm
        if s_thres > 0 and np.linalg.norm(pick1 - pick2) < s_thres:
            pick2, place2 = np.array([0, 0]), np.array([0, 0])

        # drop an arm whose pick->place displacement is shorter than a_len_thres
        if a_thres > 0 and np.linalg.norm(pick1 - place1) < a_thres:
            pick1, place1 = np.array([0, 0]), np.array([0, 0])
        if a_thres > 0 and np.linalg.norm(pick2 - place2) < a_thres:
            pick2, place2 = np.array([0, 0]), np.array([0, 0])

        if self.cfgs.get('debug_viz', False):
            self._debug_viz(obs, nobs, flow, pred_1, pred_2, pick1, pick2,
                            place1, place2, logits1, logits2)

        print(f"{pick1} {place1} {pick2} {place2}")

        info = {}
        if not predict_place:
            info['flow'] = flow_arr
        return np.array([pick1, place1, pick2, place2]), np.array([pred_1, pred_2]), info

    def _debug_viz(self, obs, nobs, flow, pred_1, pred_2, pick1, pick2,
                   place1, place2, logits1, logits2):
        """Plot inputs, predicted points and pick heatmaps for batch item 0 (debug aid)."""
        i = 0
        img = cv2.cvtColor(obs[i, 0, :, :].detach().cpu().numpy(), cv2.COLOR_GRAY2RGB)
        nimg = cv2.cvtColor(nobs[i, 0, :, :].detach().cpu().numpy(), cv2.COLOR_GRAY2RGB)

        # RGB float colours for matplotlib: raw picks blue, masked picks yellow, places magenta
        for (u, v), color in ((pred_1, (0, 0, 1)), (pred_2, (0, 0, 1)),
                              (pick1, (1, 1, 0)), (pick2, (1, 1, 0)),
                              (place1, (1, 0, 1)), (place2, (1, 0, 1))):
            cv2.circle(img, (int(v), int(u)), 6, color, 2)

        fig, ax = plt.subplots(1, 3, figsize=(6, 3))
        fig.suptitle('inputs')
        ax[0].imshow(img)
        ax[0].set_title('obs')
        ax[1].imshow(nimg)
        ax[1].set_title('nobs')
        plot_flow(ax[2], flow[i].permute(1, 2, 0).detach().cpu().numpy(), skip=1.0)
        plt.show()
        plt.close()

        probs1 = torch.sigmoid(logits1)
        probs2 = torch.sigmoid(logits2)
        fig, ax = plt.subplots(1, 4)
        fig.suptitle('outputs')
        panels = ((logits1, 'logits 1'), (probs1, 'heatmap 1'),
                  (logits2, 'logits 2'), (probs2, 'heatmap 2'))
        for a, (t, title) in zip(ax, panels):
            a.imshow(t[i, 0, :, :].clone().detach().cpu().numpy())
            a.set_title(title)
        plt.show()
        plt.close()

    def calc_metric(self, start, goal):
        """Return (mean particle distance current->goal, that divided by the
        mean particle distance start->goal)."""
        current = pyflex.get_positions().reshape(-1, 4)[:, :3]
        start_dist = np.linalg.norm(goal - start, axis=1).mean()
        curr_dist = np.linalg.norm(goal - current, axis=1).mean()
        normed = curr_dist / start_dist

        return curr_dist, normed

    def get_flow_place_pt(self, u, v, flow):
        """Compute a place point by adding flow to the pick point (u, v).

        The u (v) offset is read at the nearest pixel whose channel-0 (channel-1)
        flow is nonzero; the result is clipped to [0, 199].
        """
        flow_u_idxs = np.argwhere(flow[0, :, :])
        flow_v_idxs = np.argwhere(flow[1, :, :])
        nearest_u_idx = flow_u_idxs[((flow_u_idxs - [u, v]) ** 2).sum(1).argmin()]
        nearest_v_idx = flow_v_idxs[((flow_v_idxs - [u, v]) ** 2).sum(1).argmin()]

        flow_u = flow[0, nearest_u_idx[0], nearest_u_idx[1]]
        flow_v = flow[1, nearest_v_idx[0], nearest_v_idx[1]]

        new_u = np.clip(u + flow_u, 0, 199)
        new_v = np.clip(v + flow_v, 0, 199)

        return new_u, new_v

    def rotate(self, obs, goal, angle):
        """Rotate both ``obs`` and ``goal`` by ``angle`` degrees with ``TF.affine``."""
        obs = TF.affine(obs, angle=angle, translate=(0, 0), scale=1.0, shear=0)
        goal = TF.affine(goal, angle=angle, translate=(0, 0), scale=1.0, shear=0)
        return obs, goal

    def aug_uv(self, uv, angle, dx, dy, size=719):
        """Rotate points ``uv`` (N, 2) by ``angle`` degrees about the image
        centre, shift column 1 by ``dx`` and column 0 by ``dy``, clip to [0, size].

        Works on a float64 copy so integer input is accepted and ``uv`` is not mutated.
        """
        uvt = np.array(uv, dtype=np.float64)
        rad = np.deg2rad(angle)
        R = np.array([
            [np.cos(rad), -np.sin(rad)],
            [np.sin(rad), np.cos(rad)]])
        uvt -= size / 2
        uvt = np.dot(R, uvt.T).T
        uvt += size / 2
        uvt[:, 1] += dx
        uvt[:, 0] += dy

        return np.clip(uvt, 0, size)

    def pad(self, im, pad_amount):
        """Apply ``transforms.Pad(pad_amount)`` to ``im`` then resize to 200x200."""
        # only torchvision's functional API is imported at module level
        from torchvision import transforms
        trans = transforms.Compose([transforms.Pad(pad_amount),
                                    transforms.Resize((200, 200))])
        return trans(im)

    def action_viz(self, img, action, unmasked_pred):
        ''' Draw the (masked) action on ``img`` and return it.

            img: cv2 image
            action: pick1, place1, pick2, place2 as (u, v) points
            unmasked_pred: pick1_pred, pick2_pred (accepted but not drawn)'''
        pick1, place1, pick2, place2 = action
        for (u1, v1), (u2, v2) in ((pick1, place1), (pick2, place2)):
            # cv2 expects (x, y) = (v, u)
            cv2.circle(img, (int(v1), int(u1)), 6, (0, 200, 0), 2)
            cv2.arrowedLine(img, (int(v1), int(u1)), (int(v2), int(u2)), (0, 200, 0), 3)

        return img


if __name__ == '__main__':
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument('--goals', help="which cloth type to run", required=True, choices=['towel', 'tsh', 'rect'])
    parser.add_argument('--output_dir', help="base dir to models", default="/data/fabric_data/joint_runs")
    parser.add_argument('--run_name', help="model to evaluate", required=True)
    parser.add_argument('--checkpoint', help="checkpoint to load, otherwise all will be evaluated", type=int, default=-1)
    parser.add_argument('--repeat', help="number of iterative corrections per subgoal", type=int, default=1)
    # default 'noflow': get_action does not compute the flow image that the
    # flow-input modes need, so leaving this unset could never run.
    parser.add_argument('--input_mode', help="type of input", choices=['noflow', 'flowonly'], default='noflow')
    parser.add_argument('--headless', help="turn on headless", dest='headless', action='store_true')
    parser.add_argument('--no-headless', help="turn off headless", dest='headless', action='store_false')
    parser.add_argument('--overwrite', help="overwrite", dest='overwrite', action='store_true')
    parser.add_argument('--no-overwrite', help="don't overwrite", dest='overwrite', action='store_false')
    # type=int: the value is used as the range() step below, which rejects strings
    parser.add_argument('--skip', help="how many epochs to skip when evaluating", type=int, default=5)
    args = parser.parse_args()

    # goals = 'towel' #os #ms #towel #tsh #rect #large #debug # goal mode
    cloth_type = 'tshirt' if args.goals == 'tsh' else 'towel'

    if args.goals == 'rect':
        shape = 'rect'
    elif args.goals == 'large':
        shape = 'large'
    else:
        shape = 'default'

    edgethresh = 10 if cloth_type == 'tshirt' else 5
    cfgs = {
        'seed': 0,
        'im_width': 200,
        'img_type': 'depth',
        'on_table': True,
        'horizon': 1,
        'run_name': args.run_name,
        'load_iter': 0,
        'goal_folder': '',
        'output_dir': args.output_dir,
        'actmask': 'none',  # none, edge, corner
        'res_name': '',
        'res_iter': 0,
        'unfold': False,
        'prob_type': 'sigmoid',
        'load_finetuned': False,
        'flow_align': False,
        'headless': args.headless,  # set this for normal eval
        'cloth_type': cloth_type,
        'shape': shape,
        'edgethresh': edgethresh,
        'input_mode': args.input_mode,  # noflow #obsflow #flowonly # use for depthin or noflow ablation
        'goals': args.goals,
        's_pick_thres': 30,  # 30 #single arm action threshold
        'a_len_thres': 10,  # action length threshold. default 0 for repeat=1
        'repeat': args.repeat,  # number of iterative corrective actions default: 1
        'record': False,  # record videos, the window position is hardcoded needs to be changed, set headless = False
        'skip': args.skip
    }

    avg_metrics = []
    full_mean = []

    cfgs['unfold'] = False
    cfgs['load_iter'] = 300
    env = EnvRollout(cfgs)

    # NOTE(review): EnvRollout.run names images '{goal_idx}_{j}_{l}_fold.png' and
    # save_dir does not encode the checkpoint, so these are goal indices, not
    # checkpoint numbers -- the resume/skip check below does not do what it
    # looks like. TODO: record evaluated checkpoints explicitly.
    ckpnts = sorted([int(x.split('_')[0]) for x in os.listdir(f'{env.save_dir}') if 'fold.png' in x])

    rng = range(300, 501, args.skip) if args.checkpoint == -1 else range(args.checkpoint, args.checkpoint + 1, args.skip)
    for i in rng:
        print(i)
        if i in ckpnts and not args.overwrite:
            continue

        print(f"loading {i}")
        print("folding:")
        try:
            env.cfgs['load_iter'] = i
            env.load_model()
            fold_mean, fold_norm = env.run()
        except EOFError:
            print("EOFError. skipping...")
            continue
        except FileNotFoundError:
            print("File not found, skipping...")
            continue
        except Exception:
            # keep going over the remaining checkpoints, but don't hide the cause
            traceback.print_exc()
            print("Exception, skipping...")
            continue

        full_mean.append(fold_mean)
        print(f"mean: {np.mean(fold_mean)} norm: {np.mean(fold_norm)}")

        avg_metrics.append([i, np.mean(fold_mean), np.std(fold_mean), np.mean(fold_norm), np.std(fold_norm)])
        np.save(f'{env.save_dir}/fold_metrics.npy', avg_metrics)
        np.save(f'{env.save_dir}/all_goals.npy', full_mean)

    if not avg_metrics:
        print("No checkpoints were evaluated; nothing to summarise.")
        sys.exit(1)

    # columns: checkpoint, mean, std, norm mean, norm std
    avg_metrics = np.array(avg_metrics)
    idx = avg_metrics[:, 1].argmin()
    print(f"\nmin: {avg_metrics[idx,0]} fold mean/std: {avg_metrics[idx,1]*1000.0:.3f} {avg_metrics[idx, 2]*1000.0:.3f} norm {avg_metrics[idx,3]*1000.0:.3f} {avg_metrics[idx, 4]*1000.0:.3f} ")

    print("\nall goals:")
    for m in full_mean[idx]:
        print(m)

    if cfgs['goals'] == 'towel':
        full_mean = np.array(full_mean) * 1000.0
        best = full_mean[idx]
        # hard-coded goal-index subsets; they assume the fixed ordering of the
        # 'towel' goal set -- TODO confirm against the Goals dataset
        subsets = {
            'one-step': best[:40],
            'mul-step': best[40:],
            'one-arm': best[np.r_[:24, 32:42, 44]],
            'two-arm': best[np.r_[24:32, 42, 43, 45]],
        }

        print(f"\nall: {np.mean(full_mean):.3f} {np.std(full_mean):.3f}")
        for name, vals in subsets.items():
            print(f"{name}: {np.mean(vals):.3f} {np.std(vals):.3f}")

        with open(f'{env.save_dir}/metrics.txt', "w") as f:
            f.write(f'all mean/std: {np.mean(full_mean):.3f} {np.std(full_mean):.3f}\n')
            for name, vals in subsets.items():
                f.write(f'{name} mean/std: {np.mean(vals):.3f} {np.std(vals):.3f}\n')

    np.save(f'{env.save_dir}/fold_metrics.npy', avg_metrics)
    np.save(f'{env.save_dir}/all_goals.npy', full_mean)

    fig, ax = plt.subplots(1, 2, figsize=(10, 3))
    ax[0].set_title("mean")
    ax[0].set_xlabel("step")
    ax[0].set_ylabel("mean metric")
    ax[0].plot(avg_metrics[:, 0], avg_metrics[:, 1], label='fold')
    ax[0].legend()
    ax[1].set_title("norm")
    ax[1].set_xlabel("step")
    ax[1].set_ylabel("norm metric")
    # column 3 is the mean normalised metric (column 2 is the std of the raw metric)
    ax[1].plot(avg_metrics[:, 0], avg_metrics[:, 3], label='fold')
    ax[1].legend()
    plt.savefig(f'{env.save_dir}/fold_eval.png')
    plt.show()
