import time
import os.path as osp
import numpy as np
from copy import deepcopy

import torch
import torch.nn.functional as F

from model.trainer.base import Trainer
from model.trainer.helpers import (
    get_dataloader, prepare_model, prepare_optimizer, get_cross_shot_dataloader, get_class_dataloader
)
from model.utils import (
    pprint, ensure_path,
    Averager, Timer, count_acc,
    compute_confidence_interval,
)
from tensorboardX import SummaryWriter
from tqdm import tqdm
import torch.nn as nn
import datetime
import pandas as pd
class FSLTrainer(Trainer):
    def __init__(self, args):
        """Set up loaders, model, optimizer, loss helpers and the BN-stat index.

        After the regular setup, every sub-module of ``self.model.encoder``
        that contributes a ``running_*`` entry to the encoder's state dict is
        located, and references to its ``running_mean`` / ``running_var`` are
        stored in ``self.running_dict`` keyed by the module's dotted path.
        """
        super().__init__(args)

        loaders = get_dataloader(args)
        self.train_loader, self.val_loader, self.test_loader = loaders
        self.model = prepare_model(args)
        self.optimizer, self.lr_scheduler = prepare_optimizer(self.model, args)

        # Similarity / loss helper modules.
        self.cos = nn.CosineSimilarity(dim=-1).cuda()
        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.temperature = 2
        self.kl = nn.KLDivLoss(reduction='batchmean')
        self.mse = torch.nn.MSELoss(reduction='mean')

        # Index the running statistics of the encoder, one entry per module.
        encoder = self.model.encoder
        tracked_stats = {}
        for state_key in encoder.state_dict():
            if 'running' not in state_key:
                continue
            # Drop the trailing buffer name to get the owning module's path.
            module_path = state_key.rpartition('.')[0]
            if module_path in tracked_stats:
                continue
            # Walk the dotted path; numeric parts index into containers.
            module = encoder
            for part in module_path.split('.'):
                module = module[int(part)] if part.isdigit() else getattr(module, part)
            tracked_stats[module_path] = {
                'mean': module.running_mean,
                'var': module.running_var,
            }
        self.running_dict = tracked_stats
                
                        
    def prepare_label(self):
        # prepare one-hot label
        args = self.args
        label = torch.arange(args.way, dtype=torch.int16).repeat(args.query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        return label
    
    def train(self):
        """Run episodic training for ``args.max_epoch`` epochs.

        Every loader batch is unpacked as ``(data, aug_data, gt_label)`` and
        processed in one of two modes:

        * ``args.vallina_maml`` set: a single episode built from ``data``;
          the loss is the query cross-entropy divided by ``args.batchsize``.
        * otherwise: ``args.batchsize`` episodes that share the query split
          of ``data`` but take their support split from ``aug_data[i]``.
          The loss is the mean cross-entropy plus, unless ``args.only_aug``
          is set, a consistency term: the negative cosine similarity between
          each episode's logits and the detached mean of all logits.

        Per-step losses are dumped to ``record/<start timestamp>.csv`` after
        every epoch.  When training ends, the training log and an
        ``epoch-last`` checkpoint are saved.
        """
        # Local import: the file only imports ``os.path`` at module level.
        import os

        outloss = []
        similarityloss = []
        diffloss = []

        args = self.args
        self.model.train()
        time_now = str(datetime.datetime.now())
        if self.args.fix_BN:
            self.model.encoder.eval()

        # start FSL training
        label = self.prepare_label()
        for epoch in range(1, args.max_epoch + 1):
            self.train_epoch += 1
            self.model.train()
            if self.args.fix_BN:
                self.model.encoder.eval()

            tl1, tl2, ta = Averager(), Averager(), Averager()
            start_tm = time.time()
            self.model.zero_grad()

            for batch in tqdm(self.train_loader):
                self.train_step += 1

                data, aug_data, _ = batch
                # Data-loading time is measured once per step, so forward
                # passes of the inner episode loop are not booked as data time.
                data_tm = time.time()
                self.dt.add(data_tm - start_tm)

                outer_loss = 0
                if args.vallina_maml:
                    support = data[:args.way * args.shot].cuda()
                    query = data[args.way * args.shot:].cuda()
                    logits, _ = self.model(support, query)

                    outer_loss += F.cross_entropy(logits, label)
                    outer_loss.div_(args.batchsize)

                    outloss.append(outer_loss.item())
                    similarityloss.append(0)
                    diffloss.append(0)
                    total_loss = outer_loss
                else:
                    embeddings = []
                    for i in range(args.batchsize):
                        # Augmented support set, shared (un-augmented) queries.
                        support = aug_data[i][:args.way * args.shot].cuda()
                        query = data[args.way * args.shot:].cuda()
                        logits, _ = self.model(support, query)

                        embeddings.append(logits)
                        outer_loss += F.cross_entropy(logits, label)

                    embeddings = torch.stack(embeddings)

                    # Pull every episode's logits towards their detached mean.
                    similarity_loss = 0
                    center = embeddings.mean(0).detach()
                    for k in range(embeddings.shape[0]):
                        similarity_loss += -1 * self.cos(embeddings[k], center).mean()
                    similarity_loss.div_(args.batchsize)

                    outer_loss.div_(args.batchsize)

                    outloss.append(outer_loss.item())
                    similarityloss.append(similarity_loss.item())
                    diffloss.append(0)
                    if args.only_aug:
                        total_loss = outer_loss
                    else:
                        total_loss = outer_loss + similarity_loss

                tl2.add(total_loss.item())

                forward_tm = time.time()
                self.ft.add(forward_tm - data_tm)
                # In the multi-episode branch this is the accuracy of the
                # last episode only.
                acc = count_acc(logits, label)

                tl1.add(total_loss.item())
                ta.add(acc)

                total_loss.backward()
                backward_tm = time.time()
                self.bt.add(backward_tm - forward_tm)
                self.optimizer.step()
                optimizer_tm = time.time()
                self.ot.add(optimizer_tm - backward_tm)
                self.model.zero_grad()

                self.try_logging(tl1, tl2, ta)
                # refresh start_tm
                start_tm = time.time()

            # The CSV is rewritten after every epoch with all steps so far.
            # Create the target directory first: to_csv does not create it.
            os.makedirs('record', exist_ok=True)
            df = pd.DataFrame([outloss, similarityloss, diffloss]).T
            df.columns = ['out_loss', 'similarity_loss', 'diff_loss']
            df.to_csv('record/' + time_now + '.csv')

            self.lr_scheduler.step()
            # NOTE: `acc` is the accuracy of the epoch's final step, not the
            # running average accumulated in `ta`.
            print('LOG: Epoch-{}: Train Acc-{}'.format(epoch, acc))
            self.try_evaluate(epoch)

            print('ETA:{}/{}'.format(
                    self.timer.measure(),
                    self.timer.measure(self.train_epoch / args.max_epoch))
            )

        torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
        self.save_model('epoch-last')

    def evaluate(self, data_loader):
        """Evaluate the model on up to ``args.num_eval_episodes`` episodes.

        ``args.way/shot/query`` are temporarily replaced by their ``eval_*``
        counterparts and restored before returning.  The encoder's running
        statistics are snapshotted up front and copied back after every
        episode.

        Args:
            data_loader: iterable of batches; ``batch[0]`` holds the support
                samples followed by the query samples.

        Returns:
            ``(vl, va, vap)``: the first value returned by
            ``compute_confidence_interval`` for the per-episode losses, and
            both values returned for the per-episode accuracies.
        """
        args = self.args
        args.old_way, args.old_shot, args.old_query = args.way, args.shot, args.query
        args.way, args.shot, args.query = args.eval_way, args.eval_shot, args.eval_query
        # evaluation mode
        self.model.eval()
        # record the running mean and variance before validation
        for stats in self.running_dict.values():
            stats['mean_copy'] = deepcopy(stats['mean'])
            stats['var_copy'] = deepcopy(stats['var'])

        record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
        for i, batch in enumerate(data_loader, 1):
            if torch.cuda.is_available():
                data = batch[0].cuda()
            else:
                data = batch[0]

            support = data[:args.eval_way * args.eval_shot]
            query = data[args.eval_way * args.eval_shot:]
            logits = self.model.forward_eval(support, query, step=args.eval_inner_iters)

            # Reset the statistics to the snapshot.  The copy must be done in
            # place: rebinding the dict entry to a new tensor would leave the
            # module's own buffers untouched.
            with torch.no_grad():
                for stats in self.running_dict.values():
                    stats['mean'].copy_(stats['mean_copy'])
                    stats['var'].copy_(stats['var_copy'])

            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i-1, 0] = loss.item()
            record[i-1, 1] = acc
            del data, support, query, logits, loss
            torch.cuda.empty_cache()
            if i == record.shape[0]:
                break

        assert(i == record.shape[0])
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])

        # train mode
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
            # `encoder_repo` is not referenced anywhere else in this file, so
            # only switch it when the model actually has one.
            encoder_repo = getattr(self.model, 'encoder_repo', None)
            if encoder_repo is not None:
                encoder_repo.eval()

        args.way, args.shot, args.query = args.old_way, args.old_shot, args.old_query
        return vl, va, vap

    def evaluate_test(self):
        """Load ``max_acc.pth`` from ``args.save_path`` and test on ``self.test_loader``.

        ``args.way/shot/query`` are temporarily replaced by their ``eval_*``
        counterparts.  The encoder's running statistics are snapshotted after
        the checkpoint is loaded and copied back after every episode.  The
        loader is expected to yield exactly ``args.test_times`` batches.

        Results are stored in ``self.trlog`` under ``test_acc``,
        ``test_acc_interval`` and ``test_loss``, printed, and returned as
        ``(vl, va, vap)``.
        """
        args = self.args
        args.old_way, args.old_shot, args.old_query = args.way, args.shot, args.query
        args.way, args.shot, args.query = args.eval_way, args.eval_shot, args.eval_query
        # evaluation mode
        self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
        self.model.eval()
        # record the running mean and variance before testing
        for stats in self.running_dict.values():
            stats['mean_copy'] = deepcopy(stats['mean'])
            stats['var_copy'] = deepcopy(stats['var'])
        record = np.zeros((args.test_times, 2))  # loss and acc
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
        for i, batch in tqdm(enumerate(self.test_loader, 1)):
            if torch.cuda.is_available():
                data = batch[0].cuda()
            else:
                data = batch[0]
            support = data[:args.eval_way * args.eval_shot]
            query = data[args.eval_way * args.eval_shot:]
            logits = self.model.forward_eval(support, query)

            # Reset the statistics to the snapshot.  The copy must be done in
            # place: rebinding the dict entry to a new tensor would leave the
            # module's own buffers untouched.
            with torch.no_grad():
                for stats in self.running_dict.values():
                    stats['mean'].copy_(stats['mean_copy'])
                    stats['var'].copy_(stats['var_copy'])

            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i-1, 0] = loss.item()
            record[i-1, 1] = acc
            del data, support, query, logits, loss
            torch.cuda.empty_cache()

        assert(i == record.shape[0])
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])

        self.trlog['test_acc'] = va
        self.trlog['test_acc_interval'] = vap
        self.trlog['test_loss'] = vl

        print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
        print('Test acc={:.4f} + {:.4f}\n'.format(
                self.trlog['test_acc'],
                self.trlog['test_acc_interval']))

        args.way, args.shot, args.query = args.old_way, args.old_shot, args.old_query
        return vl, va, vap

    def evaluate_test_cross_shot(self):
        """Test the best checkpoint for several values of ``step``.

        For each ``step`` in a fixed list, the checkpoint (``args.test_path``
        if given, else ``max_acc.pth`` under ``args.save_path``) is reloaded
        and evaluated on a loader from ``get_cross_shot_dataloader`` for every
        entry of ``num_shots``; ``step`` is passed as the third positional
        argument of ``self.model.forward_eval``.  Per-shot accuracy is printed
        and written to a ``<acc>+<interval>-CrossShot`` file under
        ``args.save_path``.

        Note that ``args.eval_shot`` is overwritten with the evaluated shot
        and is not restored.
        """
        args = self.args

        for step in [5, 10, 15, 20, 30, 40]:
            if args.test_path is None:
                self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
            else:
                print(args.test_path)
                self.model.load_state_dict(torch.load(args.test_path)['params'])
            self.model.eval()
            # record the running mean and variance before testing
            for stats in self.running_dict.values():
                stats['mean_copy'] = deepcopy(stats['mean'])
                stats['var_copy'] = deepcopy(stats['var'])
            # num_shots = [1, 5, 10, 20, 30, 50]
            num_shots = [5]
            record = np.zeros((args.test_times, len(num_shots)))  # acc per shot
            label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
            label = label.type(torch.LongTensor)
            if torch.cuda.is_available():
                label = label.cuda()

            for s_index, shot in enumerate(num_shots):
                test_loader = get_cross_shot_dataloader(args, shot)
                args.eval_shot = shot
                args.old_way, args.old_shot, args.old_query = args.way, args.shot, args.query
                args.way, args.shot, args.query = args.eval_way, args.eval_shot, args.eval_query
                for i, batch in tqdm(enumerate(test_loader, 1)):
                    if torch.cuda.is_available():
                        data = batch[0].cuda()
                    else:
                        data = batch[0]
                    # `data` already lives on the right device, so the slices
                    # need no further (unguarded) `.cuda()` call.
                    support = data[:args.way * args.shot]
                    query = data[args.way * args.shot:]

                    logits = self.model.forward_eval(support, query, step)

                    # Reset the statistics to the snapshot.  The copy must be
                    # done in place: rebinding the dict entry to a new tensor
                    # would leave the module's own buffers untouched.
                    with torch.no_grad():
                        for stats in self.running_dict.values():
                            stats['mean'].copy_(stats['mean_copy'])
                            stats['var'].copy_(stats['var_copy'])

                    record[i-1, s_index] = count_acc(logits, label)
                    del data, support, query, logits
                    torch.cuda.empty_cache()

                assert(i == record.shape[0])

                va, vap = compute_confidence_interval(record[:, s_index])
                print('Shot {} Test acc={:.4f} + {:.4f}\n'.format(shot, va, vap))
                args.way, args.shot, args.query = args.old_way, args.old_shot, args.old_query

            # The file name uses the accuracy of the last evaluated shot.
            with open(osp.join(self.args.save_path, '{}+{}-CrossShot'.format(va, vap)), 'w') as f:
                f.write('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                        self.trlog['max_acc_epoch'],
                        self.trlog['max_acc'],
                        self.trlog['max_acc_interval']))
                for s_index, shot in enumerate(num_shots):
                    va, vap = compute_confidence_interval(record[:, s_index])
                    f.write('Shot {} Test acc={:.4f} + {:.4f}\n'.format(shot, va, vap))

    def evaluate_test_cross_shot_sim(self):
        """Test accuracy and logit consistency across augmented support sets.

        For each ``step`` in a fixed list, the checkpoint (``args.test_path``
        if given, else ``max_acc.pth`` under ``args.save_path``) is reloaded.
        Every test batch is unpacked as ``(data, aug_data, gt_label)``;
        ``args.batchsize`` episodes share the query split of ``data`` and take
        their support split from ``aug_data[n]``.  Per batch, the mean
        accuracy and the mean negative cosine similarity between each
        episode's logits and the detached mean logits are recorded.

        Results are printed, optionally dumped as CSV (``args.get_data``),
        and written to a ``<acc>+<interval>-CrossShot`` file under
        ``args.save_path``.  ``args.eval_shot`` is overwritten with the
        evaluated shot and is not restored.
        """
        args = self.args

        #for step in [5,10,15,20,30,40,50,60,70]:
        for step in [5, 10, 15, 20]:

            if args.test_path is None:
                self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
            else:
                print(args.test_path)
                self.model.load_state_dict(torch.load(args.test_path)['params'])

            self.model.eval()
            # record the running mean and variance before testing
            for stats in self.running_dict.values():
                stats['mean_copy'] = deepcopy(stats['mean'])
                stats['var_copy'] = deepcopy(stats['var'])
            # num_shots = [1, 5, 10, 20, 30, 50]
            num_shots = [1]
            record = np.zeros((args.test_times, len(num_shots)))    # acc per shot
            sim_list = np.zeros((args.test_times, len(num_shots)))  # similarity per shot
            label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
            label = label.type(torch.LongTensor)
            if torch.cuda.is_available():
                label = label.cuda()

            for s_index, shot in enumerate(num_shots):
                test_loader = get_cross_shot_dataloader(args, shot)
                args.eval_shot = shot
                args.old_way, args.old_shot, args.old_query = args.way, args.shot, args.query
                args.way, args.shot, args.query = args.eval_way, args.eval_shot, args.eval_query
                for i, batch in tqdm(enumerate(test_loader, 1)):
                    acc = 0
                    data, aug_data, _ = batch
                    embeddings = []
                    for n in range(args.batchsize):
                        # Augmented support set, shared (un-augmented) queries.
                        support = aug_data[n][:args.way * args.shot].cuda()
                        query = data[args.way * args.shot:].cuda()
                        logits = self.model.forward_eval(support, query, step)

                        embeddings.append(logits)
                        acc += count_acc(logits, label)

                    acc /= args.batchsize

                    embeddings = torch.stack(embeddings)

                    similarity_loss = 0
                    center = embeddings.mean(0).detach()
                    for k in range(embeddings.shape[0]):
                        similarity_loss += -1 * self.cos(embeddings[k], center).mean()
                    similarity_loss.div_(args.batchsize)

                    # Reset the statistics to the snapshot.  The copy must be
                    # done in place: rebinding the dict entry to a new tensor
                    # would leave the module's own buffers untouched.
                    with torch.no_grad():
                        for stats in self.running_dict.values():
                            stats['mean'].copy_(stats['mean_copy'])
                            stats['var'].copy_(stats['var_copy'])

                    record[i-1, s_index] = acc
                    sim_list[i-1, s_index] = similarity_loss.item()
                    del data, support, query, logits
                    torch.cuda.empty_cache()

                assert(i == record.shape[0])

                if args.get_data:
                    df = pd.DataFrame([record[:, s_index], sim_list[:, s_index]])
                    df.to_csv(osp.join(self.args.save_path, 'step_'+str(step)+'data_{}.csv'.format(shot)), index=False, header=False)

                # Keep accuracy and similarity statistics in separate names so
                # neither overwrites the other.
                va, vap = compute_confidence_interval(record[:, s_index])
                print('Shot {} Test acc={:.4f} + {:.4f}\n'.format(shot, va, vap))
                sim_mean, sim_interval = compute_confidence_interval(sim_list[:, s_index])
                print('Shot {} Test sim={:.4f} + {:.4f}\n'.format(shot, sim_mean, sim_interval))
                args.way, args.shot, args.query = args.old_way, args.old_shot, args.old_query

            # The file name uses the accuracy of the last evaluated shot.
            with open(osp.join(self.args.save_path, '{}+{}-CrossShot'.format(va, vap)), 'w') as f:
                f.write('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                        self.trlog['max_acc_epoch'],
                        self.trlog['max_acc'],
                        self.trlog['max_acc_interval']))
                for s_index, shot in enumerate(num_shots):
                    va, vap = compute_confidence_interval(record[:, s_index])
                    f.write('Shot {} Test acc={:.4f} + {:.4f}\n'.format(shot, va, vap))
                    # Report the similarity statistics on the "sim" line, not
                    # the accuracy values.
                    sim_mean, sim_interval = compute_confidence_interval(sim_list[:, s_index])
                    f.write('Shot {} Test sim={:.4f} + {:.4f}\n'.format(shot, sim_mean, sim_interval))

    def evaluate_test_cross_shot_proto(self, checkpoint_path='max_acc_8468.pth'):
        """Test a checkpoint with ``forward_eval(..., proto=True)`` for several steps.

        For each ``step`` in a fixed list, the checkpoint is reloaded and
        evaluated on a loader from ``get_cross_shot_dataloader`` for every
        entry of ``num_shots``.  Per-shot accuracy is printed and written to
        a ``<acc>+<interval>-CrossShot`` file under ``args.save_path``.

        Side effects on ``args``: ``args.bs`` is set to 4 and
        ``args.eval_shot`` is overwritten with the evaluated shot; neither is
        restored.

        Args:
            checkpoint_path: file passed to ``torch.load``; its ``'params'``
                entry is loaded into the model.  Defaults to the previously
                hard-coded ``'max_acc_8468.pth'`` (relative to the working
                directory).
        """
        args = self.args
        # NOTE(review): presumably read by get_cross_shot_dataloader - confirm.
        args.bs = 4
        for step in [0, 5, 10, 15, 20, 30]:

            self.model.load_state_dict(torch.load(checkpoint_path)['params'])
            self.model.eval()
            # record the running mean and variance before testing
            for stats in self.running_dict.values():
                stats['mean_copy'] = deepcopy(stats['mean'])
                stats['var_copy'] = deepcopy(stats['var'])
            # num_shots = [1, 5, 10, 20, 30, 50]
            num_shots = [5]
            record = np.zeros((args.test_times, len(num_shots)))  # acc per shot
            label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
            label = label.type(torch.LongTensor)
            if torch.cuda.is_available():
                label = label.cuda()

            for s_index, shot in enumerate(num_shots):
                test_loader = get_cross_shot_dataloader(args, shot)
                args.eval_shot = shot
                args.old_way, args.old_shot, args.old_query = args.way, args.shot, args.query
                args.way, args.shot, args.query = args.eval_way, args.eval_shot, args.eval_query
                for i, batch in tqdm(enumerate(test_loader, 1)):
                    if torch.cuda.is_available():
                        data = batch[0].cuda()
                    else:
                        data = batch[0]
                    # `data` already lives on the right device, so the slices
                    # need no further (unguarded) `.cuda()` call.
                    support = data[:args.way * args.shot]
                    query = data[args.way * args.shot:]

                    logits = self.model.forward_eval(support, query, step, proto=True)

                    # Reset the statistics to the snapshot.  The copy must be
                    # done in place: rebinding the dict entry to a new tensor
                    # would leave the module's own buffers untouched.
                    with torch.no_grad():
                        for stats in self.running_dict.values():
                            stats['mean'].copy_(stats['mean_copy'])
                            stats['var'].copy_(stats['var_copy'])

                    record[i-1, s_index] = count_acc(logits, label)
                    del data, support, query, logits
                    torch.cuda.empty_cache()

                assert(i == record.shape[0])

                va, vap = compute_confidence_interval(record[:, s_index])
                print('Shot {} Test acc={:.4f} + {:.4f}\n'.format(shot, va, vap))
                args.way, args.shot, args.query = args.old_way, args.old_shot, args.old_query

            # The file name uses the accuracy of the last evaluated shot.
            with open(osp.join(self.args.save_path, '{}+{}-CrossShot'.format(va, vap)), 'w') as f:
                f.write('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                        self.trlog['max_acc_epoch'],
                        self.trlog['max_acc'],
                        self.trlog['max_acc_interval']))
                for s_index, shot in enumerate(num_shots):
                    va, vap = compute_confidence_interval(record[:, s_index])
                    f.write('Shot {} Test acc={:.4f} + {:.4f}\n'.format(shot, va, vap))
                    
                    
    def final_record(self):
        # save the best performance in a txt file

        with open(osp.join(self.args.save_path, '{}+{}'.format(self.trlog['test_acc'], self.trlog['test_acc_interval'])), 'w') as f:
            f.write('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                    self.trlog['max_acc_epoch'],
                    self.trlog['max_acc'],
                    self.trlog['max_acc_interval']))
            f.write('Test acc={:.4f} + {:.4f}\n'.format(
                    self.trlog['test_acc'],
                    self.trlog['test_acc_interval']))            
