import torch
import numpy as np
from tqdm import tqdm
import imageio

from .utils import visualize as vis
from .utils.common import mse2psnr, reduce_dict, gather_all
from .utils.common import get_rank, get_world_size

import os
from collections import defaultdict

class SRTTrainer:
    """Training / evaluation / visualisation driver for an encoder-decoder scene model.

    The wrapped ``model`` must expose ``model.encoder(images, camera_pos, rays)``
    and ``model.decoder(z, camera_pos, rays, **render_kwargs)``. The decoder
    returns a 4-tuple ``(pixels, slot_pixels, slot_weights, slot_feats)``.
    """

    # Fixed RGB palette (0-255) used to tint per-slot attention maps in
    # ``render_image``. Slot ``k`` always uses entry ``k % len(SLOT_COLORS)``.
    # This keeps colours consistent across views and calls, and works for any
    # number of slots.
    SLOT_COLORS = (
        (60, 255, 60),     # Muted Lime
        (255, 60, 60),     # Muted Red
        (60, 60, 255),     # Muted Blue
        (60, 255, 255),    # Muted Cyan
        (255, 60, 255),    # Muted Magenta
        (255, 255, 60),    # Muted Yellow
        (200, 100, 100),   # Muted Maroon
        (100, 100, 200),   # Muted Navy
        (255, 90, 157),    # Muted DeepPink
        (208, 60, 251),    # Muted DarkViolet
        (255, 202, 213),   # Light Pink
        (203, 236, 250),   # Light Blue
        (164, 248, 164),   # Light Green
        (250, 148, 148),   # Light Coral
        (241, 180, 241),   # Plum
        (132, 225, 190),   # Medium Aquamarine
        (253, 170, 142),   # Dark Salmon
        (206, 244, 250),   # Powder Blue
        (250, 240, 170),   # Khaki
        (240, 230, 255),   # Lavender
    )

    def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
        """Store the model, optimiser and the config values used by the losses.

        Args:
            model: module with ``encoder`` and ``decoder`` attributes.
            optimizer: optimiser stepping ``model``'s parameters.
            cfg: nested config dict. Reads
                ``cfg['model']['encoder_kwargs']`` ('num_slots', 'slot_dim',
                optional 'num_bg') and ``cfg['training']`` ('mask',
                'mask_scale'). ``render_image`` also reads
                ``cfg['data']['num_points']`` and ``cfg['training']['batch_size']``.
            device: target device for batch tensors.
            out_dir: directory where visualisations are written.
            render_kwargs: extra keyword arguments forwarded to the decoder.
        """
        self.model = model
        self.optimizer = optimizer
        self.config = cfg
        self.device = device
        self.out_dir = out_dir
        self.render_kwargs = render_kwargs

        encoder_cfg = cfg['model']['encoder_kwargs']
        # Number of leading slots treated as background by the mask loss
        # (0 means no background slots).
        self.num_bg = encoder_cfg.get('num_bg', 0)
        self.num_slots = encoder_cfg['num_slots']
        self.slot_dim = encoder_cfg['slot_dim']

        # Optional background/agent mask loss and its weight.
        self.mask_loss = cfg['training']['mask']
        self.mask_loss_scale = cfg['training']['mask_scale']

    def evaluate(self, val_loader, **kwargs):
        """Evaluate on ``val_loader`` and return metrics averaged over all processes.

        Args:
            val_loader (dataloader): pytorch dataloader yielding batch dicts
                that contain a 'sceneid' tensor plus the keys ``compute_loss`` reads.
            **kwargs: forwarded to ``eval_step``.

        Returns:
            dict mapping metric name to a Python float. Values are first
            averaged across processes via ``reduce_dict``, then over samples.
        """
        self.model.eval()
        eval_lists = defaultdict(list)
        sceneids = []

        # Only rank 0 shows a progress bar.
        loader = val_loader if get_rank() > 0 else tqdm(val_loader)
        for data in loader:
            sceneids.append(data['sceneid'])
            for name, value in self.eval_step(data, **kwargs).items():
                eval_lists[name].append(value)

        # Collect scene ids from every process so the count covers the whole split.
        # Use the trainer's device rather than a hard-coded ``.cuda()``.
        sceneids = torch.cat(sceneids, 0).to(self.device)
        sceneids = torch.cat(gather_all(sceneids), 0)
        print(f'Evaluated {len(torch.unique(sceneids))} unique scenes.')

        eval_dict = {name: torch.cat(values, 0) for name, values in eval_lists.items()}
        eval_dict = reduce_dict(eval_dict, average=True)  # Average across processes
        eval_dict = {name: value.mean().item() for name, value in eval_dict.items()}  # Average across samples
        print('Evaluation results:')
        print(eval_dict)
        return eval_dict

    def train_step(self, data, it):
        """Run one optimisation step on ``data``.

        Returns:
            (loss, loss_terms): the scalar batch-mean loss as a float, and each
            loss term averaged over the batch as a float.
        """
        self.model.train()
        self.optimizer.zero_grad()
        loss, loss_terms = self.compute_loss(data, it)
        loss = loss.mean(0)
        loss_terms = {name: term.mean(0).item() for name, term in loss_terms.items()}
        loss.backward()
        self.optimizer.step()
        return loss.item(), loss_terms

    def compute_loss(self, data, it):
        """Compute the per-sample loss for one batch.

        The input views are encoded once. Each target view ``t`` is then
        decoded separately, and the stacked predictions are compared with
        ``data['target_pixels']``.

        Args:
            data: batch dict. Always reads 'input_images', 'input_camera_pos',
                'input_rays', 'target_pixels', 'target_camera_pos' and
                'target_rays'. With the mask loss enabled it also reads
                'fg_mask', 'ag_mask' and 'target_index'.
            it: training iteration; currently unused.

        Returns:
            (loss, loss_terms). ``loss`` is reduced over dims 1-3, presumably
            leaving one value per sample. ``loss_terms`` maps 'mse' (plus
            'bg_mask_loss' / 'ag_mask_loss' when the mask loss is on) to
            tensors of the same shape.
        """
        device = self.device

        input_images = data.get('input_images').to(device)
        input_camera_pos = data.get('input_camera_pos').to(device)
        input_rays = data.get('input_rays').to(device)

        target_pixels = data.get('target_pixels').to(device)
        target_camera_pos = data.get('target_camera_pos').to(device)
        target_rays = data.get('target_rays').to(device)

        z = self.model.encoder(input_images, input_camera_pos, input_rays)
        loss_terms = dict()

        # Decode each target view; only predicted pixels and slot weights are used.
        num_views = target_pixels.shape[1]
        pred_pixels_list = []
        weights_list = []
        for t in range(num_views):
            pred_pixel, _, weight, _ = self.model.decoder(
                z, target_camera_pos[:, t], target_rays[:, t], **self.render_kwargs)
            pred_pixels_list.append(pred_pixel)
            weights_list.append(weight)

        pred_pixels = torch.stack(pred_pixels_list, dim=1)
        weights = torch.stack(weights_list, dim=1)

        # Loss 1: reconstruction MSE.
        mse = ((pred_pixels - target_pixels) ** 2).mean((1, 2, 3))
        loss_terms['mse'] = mse
        loss = mse

        # Loss 2: background / agent mask loss.
        if self.mask_loss:
            fg_mask = data.get('fg_mask').to(device)
            ag_mask = data.get('ag_mask').to(device)
            target_pixel_indices = data.get('target_index').to(device)

            B, _, _ = target_pixel_indices.shape
            # Flatten the spatial dims of each mask to [B, T, pixels] so it can be
            # gathered at the sampled target pixel indices.
            fg_mask = fg_mask.flatten(-2, -1).reshape(B, num_views, 1, -1)[:, :, 0]
            ag_mask = ag_mask.flatten(-2, -1).reshape(B, num_views, 1, -1)[:, :, 0]

            fg_attn_targets = torch.gather(fg_mask, 2, target_pixel_indices)
            ag_attn_targets = torch.gather(ag_mask, 2, target_pixel_indices)
            # Everything that is neither foreground nor agent counts as background.
            bg_attn_targets = 1 - fg_attn_targets - ag_attn_targets

            # Background attention = summed weight of the first ``num_bg`` slots;
            # agent attention = weight of the last slot.
            bg_attn = weights[..., :self.num_bg].sum(-1, keepdim=True)
            ag_attn = weights[..., -1:]

            attns_bg_pixels = bg_attn * pred_pixels
            attns_ag_pixels = ag_attn * pred_pixels
            bg_attn_target_pixels = bg_attn_targets[..., None] * target_pixels
            ag_attn_target_pixels = ag_attn_targets[..., None] * target_pixels

            bg_mask_loss = torch.nn.functional.mse_loss(
                attns_bg_pixels, bg_attn_target_pixels, reduction='none').mean([1, 2, 3])
            ag_mask_loss = torch.nn.functional.mse_loss(
                attns_ag_pixels, ag_attn_target_pixels, reduction='none').mean([1, 2, 3])

            loss_terms['bg_mask_loss'] = bg_mask_loss
            loss_terms['ag_mask_loss'] = ag_mask_loss

            loss = loss + self.mask_loss_scale * bg_mask_loss
            loss = loss + self.mask_loss_scale * ag_mask_loss

        return loss, loss_terms

    def eval_step(self, data, full_scale=False):
        """Compute metrics for one batch without tracking gradients.

        Args:
            data: batch dict accepted by ``compute_loss``.
            full_scale: unused; kept for interface compatibility.

        Returns:
            dict with 'psnr' (from ``mse2psnr``), 'mse' and every entry of the
            ``compute_loss`` loss terms.
        """
        with torch.no_grad():
            _, loss_terms = self.compute_loss(data, 1000000)

        mse = loss_terms['mse']
        return {'psnr': mse2psnr(mse), 'mse': mse, **loss_terms}

    def visualize(self, data, mode='val'):
        """Render every target view of ``data`` and save a comparison grid.

        The columns are:
          * each input view;
          * then, for each target view, the ground truth, the rendering and
            the per-slot breakdown produced by ``render_image``.

        The grid is passed to ``vis.draw_visualization_grid`` with the path
        ``<out_dir>/renders-<mode>``. Target data comes from the
        'target_pixels_ori', 'target_camera_pos_ori' and 'target_rays_ori' keys.
        """
        self.model.eval()
        device = self.device

        with torch.no_grad():
            input_images = data.get('input_images').to(device)
            input_camera_pos = data.get('input_camera_pos').to(device)
            input_rays = data.get('input_rays').to(device)
            target_images = data.get('target_pixels_ori').to(device)
            target_camera_pos = data.get('target_camera_pos_ori').to(device)
            target_rays = data.get('target_rays_ori').to(device)

            # Move the channel axis last for plotting: [B, T, C, H, W] -> [B, T, H, W, C].
            input_images_np = np.transpose(input_images.cpu().numpy(), (0, 1, 3, 4, 2))
            target_images_np = np.transpose(target_images.cpu().numpy(), (0, 1, 3, 4, 2))

            z = self.model.encoder(input_images, input_camera_pos, input_rays)

            num_input_images = input_rays.shape[1]
            columns = []
            for i in range(num_input_images):
                header = 'input' if num_input_images == 1 else f'input {i+1}'
                columns.append((header, input_images_np[:, i], 'image'))

            num_target_views = target_camera_pos.shape[1]
            for i in range(num_target_views):
                img, slots, _ = self.render_image(
                    z, target_camera_pos[:, i], target_rays[:, i], **self.render_kwargs)
                columns.append((f'GT {i}', target_images_np[:, i], 'image'))
                columns.append((f'target view {i}', img.cpu().numpy(), 'image'))
                columns.append((f'slots {i}', slots.cpu().numpy(), 'image'))

            output_img_path = os.path.join(self.out_dir, f'renders-{mode}')
            vis.draw_visualization_grid(columns, output_img_path, row_labels=None)

    @classmethod
    def _colorize_attention(cls, attn_slots):
        """Tint each slot's attention map with its fixed palette colour.

        Args:
            attn_slots: tensor [B, K, N, C]; only channel 0 is read.

        Returns:
            Tensor [B, K, N, 3] with the same dtype/device as ``attn_slots``.
            Slot ``k`` of every batch element is tinted with
            ``SLOT_COLORS[k % len(SLOT_COLORS)] / 255``. The global RNG is not used.
        """
        num_slots = attn_slots.size(1)
        palette = torch.tensor(cls.SLOT_COLORS, dtype=attn_slots.dtype, device=attn_slots.device)
        slot_ids = torch.arange(num_slots, device=attn_slots.device) % len(cls.SLOT_COLORS)
        colors = palette[slot_ids].view(1, num_slots, 1, 3)
        return attn_slots[..., :1] * colors / 255.0

    def render_image(self, z, camera_pos, rays, **render_kwargs):
        """Render a full image in ray chunks and build a per-slot breakdown.

        Args:
            z [n, k, c]: set structured latent variables
            camera_pos [n, 3]: camera position
            rays [n, h, w, 3]: ray directions
            render_kwargs: kwargs passed on to decoder

        Returns:
            img [n, h, w, 3]: the rendered image.
            slots [n, k*h, 2*w, 3]: the left half holds each slot's pixels
                multiplied by its attention, stacked vertically. The right half
                holds the same slots' attention maps tinted with
                ``SLOT_COLORS`` and clamped to [0, 1].
            attn_slots [n, k, h*w, 3]: raw per-slot attention weights repeated
                over 3 channels.
        """
        batch_size, height, width = rays.shape[:3]
        num_slots = z.size(1)
        rays = rays.flatten(1, 2)
        num_rays = rays.shape[1]
        camera_pos = camera_pos.unsqueeze(1).repeat(1, num_rays, 1)

        # Chunk size: the training ray budget (num_points * batch_size) divided
        # across this batch and all processes. It is clamped to >= 1 so range()
        # never receives a zero step.
        max_num_rays = max(1, self.config['data']['num_points'] *
                           self.config['training']['batch_size'] //
                           (batch_size * get_world_size()))

        img = torch.zeros_like(rays)
        img_slots = img.unsqueeze(1).repeat(1, num_slots, 1, 1)
        attn_slots = torch.zeros_like(img_slots)

        for start in range(0, num_rays, max_num_rays):
            end = start + max_num_rays
            img[:, start:end], slot_pixels, attn, _ = self.model.decoder(
                z, camera_pos[:, start:end], rays[:, start:end], **render_kwargs)
            # Presumably [n, rays, k] -> [n, k, rays, 1] to broadcast over RGB.
            attn = attn[:, None].permute(0, 3, 2, 1)
            img_slots[:, :, start:end] = slot_pixels * attn
            attn_slots[:, :, start:end] = attn.repeat(1, 1, 1, 3)

        colored_masks = self._colorize_attention(attn_slots)

        # Stack slots vertically: [n, k, h*w, 3] -> [n, k*h, w, 3].
        slots = img_slots.reshape(batch_size, height * num_slots, width, 3)
        masks = colored_masks.reshape(batch_size, height * num_slots, width, 3)
        slots = torch.cat([slots, torch.clamp(masks, 0.0, 1.0)], dim=2)

        img = img.view(batch_size, height, width, 3)
        return img, slots, attn_slots
